Source code for thesis.workflows.hcp.nodes.roi_mapnode

"""MapNode/JoinNode factories for the atlas ROI pipeline.

These replace the per-source ``prepare_roi_extractor``/``prepare_roi_transformer``/
``prepare_roi_validator``/``prepare_roi_merger`` factory pattern used by the
legacy HCP workflow. The new design uses Nipype's native ``MapNode`` to iterate
over atlas sources and a single merge ``Node`` (originally a ``JoinNode``) that
receives the MapNode's aggregated list outputs — eliminating the per-source
factory loop and the binary merger reduce tree.
"""

from pathlib import Path
from typing import Optional

from nipype import MapNode, Node
from nipype.interfaces.utility import Function

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext

from ..common import (
    ValidationThresholds,
    format_patient_path,
    get_atlas_sources,
    get_atlas_transform_spec,
    get_trial_root,
    resolve_with_fallback,
)
from ..operations import (
    extract_rois_task,
    merge_rois_task,
    transform_rois_task,
    validate_warped_rois_passthrough_task,
)

_ROI_FIELDS = ("seed", "waypoints_file", "stop_mask", "avoid_mask", "target_mask")


def prepare_atlas_extract_mapnode(
    config: PipelineConfig,
    context: ProcessingContext,
    out_dir: Path,
    name: str = "atlas_extract",
) -> Optional[MapNode]:
    """Create a ``MapNode`` that calls ``extract_rois_task`` for each atlas source.

    Each source returned by ``get_atlas_sources(config)`` contributes one
    iteration, which receives:

    - ``roi_file``: the source's ``roi_file`` entry, resolved to a path string;
    - ``label_file``: the resolved ``label_file`` entry, or ``""`` when the
      source declares none;
    - ``waypoint_labels``: the source's ``waypoint_labels`` value (``{}`` if
      absent);
    - ``output_dir``: ``<out_dir>/rois/<source name>``, where the name falls
      back to ``atlas_<1-based index>``.

    ``hemisphere`` is a shared, non-iterated input set to ``"both"``.

    Relative paths go through ``resolve_with_fallback``. The primary base is
    ``context.input_dir`` (or the current working directory when it is unset),
    with ``context.data_dir`` as the fallback.

    Returns:
        The configured ``MapNode``, or ``None`` when the config declares no
        atlas sources.
    """
    sources = get_atlas_sources(config)
    if not sources:
        return None

    raw_input_dir = context.input_dir
    base_dir = Path(raw_input_dir if raw_input_dir else ".").resolve()

    def _resolve(raw_path) -> str:
        # A fresh fallback list per call, exactly like an inline literal.
        return str(resolve_with_fallback(raw_path, base_dir, [context.data_dir]))

    # Insertion order defines both the Function input order and the iterfield.
    iterated: dict[str, list] = {
        "roi_file": [],
        "label_file": [],
        "waypoint_labels": [],
        "output_dir": [],
    }
    for position, source in enumerate(sources, start=1):
        source_label = str(source.get("name", f"atlas_{position}"))
        iterated["roi_file"].append(_resolve(source.get("roi_file", "")))
        raw_label_file = source.get("label_file", "")
        iterated["label_file"].append(_resolve(raw_label_file) if raw_label_file else "")
        iterated["waypoint_labels"].append(source.get("waypoint_labels", {}))
        iterated["output_dir"].append(str(out_dir / "rois" / source_label))

    extractor = MapNode(
        Function(
            input_names=[*iterated, "hemisphere"],
            output_names=list(_ROI_FIELDS),
            function=extract_rois_task,
        ),
        iterfield=list(iterated),
        name=name,
    )
    for field_name, values in iterated.items():
        setattr(extractor.inputs, field_name, values)
    extractor.inputs.hemisphere = "both"
    return extractor
def prepare_atlas_transform_mapnode(
    config: PipelineConfig,
    context: ProcessingContext,
    out_dir: Path,
    name: str = "atlas_transform",
) -> Optional[MapNode]:
    """Build a ``MapNode`` running ``transform_rois_task`` per atlas source.

    Returns ``None`` if no atlas sources have a transform configured (in which
    case the upstream extract MapNode connects directly to the validate
    MapNode).

    A source's spec comes from ``get_atlas_transform_spec``. If it lacks either
    ``template_to_patient`` or ``reference_image``, the source still appears as
    an iteration. It then gets an empty ``warp_field`` list and an empty
    ``reference_image``, which causes the task to no-op-passthrough.

    ``template_to_patient`` may be a single path or a list/tuple of paths. Each
    path, and ``reference_image``, first has the patient id substituted via
    ``format_patient_path``. It is then resolved against the trial root, with
    ``context.input_dir`` (or the CWD) and ``context.data_dir`` as fallbacks.

    The ROI inputs (``seed`` … ``target_mask``) and ``_ordering_signal`` are not
    set here. Presumably the caller connects them from upstream nodes (TODO
    confirm against the workflow wiring).

    Args:
        config: Pipeline configuration holding the atlas sources.
        context: Processing context supplying paths and ``patient_id``.
        out_dir: Root under which ``rois_transformed/<source>`` dirs are placed.
        name: Nipype node name.

    Returns:
        The configured ``MapNode``, or ``None`` when there are no atlas sources
        or none of them has a usable transform.
    """
    atlas_sources = get_atlas_sources(config)
    if not atlas_sources:
        return None

    input_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".").resolve()
    # template_to_patient / reference_image typically point at the current run's
    # own registration outputs (e.g. outputs/{pid}/registration/transforms/...),
    # so anchor relative paths at the trial root. input_dir / data_dir remain as
    # fallbacks for precomputed transforms that already exist there (standalone).
    trial_root = get_trial_root(context)
    patient_id = getattr(context, "patient_id", "")

    def _resolve(raw_path) -> str:
        """Substitute the patient id into *raw_path* and resolve it to a string."""
        return str(
            resolve_with_fallback(
                format_patient_path(raw_path, patient_id),
                trial_root,
                [input_dir, context.data_dir],
            )
        )

    warp_fields: list[list[str]] = []
    reference_images: list[str] = []
    output_dirs: list[str] = []
    any_transform = False
    for index, source in enumerate(atlas_sources, start=1):
        source_name = str(source.get("name", f"atlas_{index}"))
        output_dirs.append(str(out_dir / "rois_transformed" / source_name))

        spec = get_atlas_transform_spec(config, source) or {}
        warp_raw = spec.get("template_to_patient", "")
        ref_raw = spec.get("reference_image", "")
        if not (warp_raw and ref_raw):
            # Keep the iteration so this node's iterfield lists stay the same
            # length as the per-source ROI lists, but pass empty transform inputs.
            warp_fields.append([])
            reference_images.append("")
            continue

        any_transform = True
        # Accept a single warp path or a sequence of them (YAML lists or Python
        # tuples); normalise to a list so both shapes share one resolution path.
        if isinstance(warp_raw, (list, tuple)):
            warp_values = list(warp_raw)
        else:
            warp_values = [warp_raw]
        warp_fields.append([_resolve(value) for value in warp_values])
        reference_images.append(_resolve(ref_raw))

    if not any_transform:
        return None

    mapnode = MapNode(
        Function(
            input_names=[
                "seed",
                "waypoints_file",
                "stop_mask",
                "avoid_mask",
                "target_mask",
                "warp_field",
                "reference_image",
                "output_dir",
                "hemisphere",
                "_ordering_signal",
            ],
            output_names=list(_ROI_FIELDS),
            function=transform_rois_task,
        ),
        iterfield=[
            "seed",
            "waypoints_file",
            "stop_mask",
            "avoid_mask",
            "target_mask",
            "warp_field",
            "reference_image",
            "output_dir",
        ],
        name=name,
    )
    mapnode.inputs.warp_field = warp_fields
    mapnode.inputs.reference_image = reference_images
    mapnode.inputs.output_dir = output_dirs
    # Default — overridden via wf.connect(inputs.hemisphere -> .hemisphere) so
    # each iteration of the outer hemisphere iterable scopes its outputs into
    # <out_dir>/rois_transformed/<source>/<hemisphere>/.
    mapnode.inputs.hemisphere = "both"
    return mapnode
def prepare_atlas_validate_mapnode(
    thresholds: ValidationThresholds,
    reference_image: str,
    out_dir: Path,
    name: str = "atlas_validate",
) -> MapNode:
    """Create a ``MapNode`` running ``validate_warped_rois_passthrough_task`` per ROI bundle.

    The five ``_ROI_FIELDS`` inputs are iterated, one entry per atlas source.
    The following inputs are shared across iterations:

    - ``reference_image``;
    - the four threshold values read from *thresholds*: ``min_voxels``,
      ``singularity_threshold``, ``volume_ratio_min`` and ``volume_ratio_max``.

    The node's outputs reuse the ROI field names.

    ``out_dir`` is accepted but not used by this factory.
    """
    threshold_keys = (
        "min_voxels",
        "singularity_threshold",
        "volume_ratio_min",
        "volume_ratio_max",
    )
    validator = MapNode(
        Function(
            input_names=[*_ROI_FIELDS, "reference_image", *threshold_keys],
            output_names=list(_ROI_FIELDS),
            function=validate_warped_rois_passthrough_task,
        ),
        iterfield=list(_ROI_FIELDS),
        name=name,
    )
    validator.inputs.reference_image = reference_image
    for key in threshold_keys:
        setattr(validator.inputs, key, thresholds[key])
    return validator
def prepare_atlas_merge_joinnode(
    joinsource_name: str,
    out_dir: Path,
    name: str = "atlas_merge",
) -> Node:
    """Create the ``Node`` that folds per-source ROI bundles into a single bundle.

    Despite the name, this returns a plain ``Node`` running ``merge_rois_task``.
    The upstream ``MapNode`` (``atlas_validate``) already aggregates its
    per-source outputs into lists, and standard Nipype connections deliver
    those lists straight to the plural inputs here (``seeds``,
    ``waypoints_files``, ``stop_masks``, ``avoid_masks``, ``target_masks``).

    A ``JoinNode`` was used here originally. Under outer ``iterables`` on
    hemisphere (``--hemisphere both-separately``), however, Nipype's join
    machinery collected zero iterations from the MapNode joinsource. As a
    result, ``merge_rois_task`` was invoked with empty lists for every field,
    producing all-empty outputs and a downstream ``probtrackx2.seed=''`` crash.

    ``output_dir`` defaults to ``<out_dir>/rois_merged`` and ``hemisphere``
    to ``"both"``.

    ``joinsource_name`` is kept in the signature for backward compatibility
    with callers and is unused.
    """
    del joinsource_name  # accepted only so existing workflow callers keep working

    list_inputs = (
        "seeds",
        "waypoints_files",
        "stop_masks",
        "avoid_masks",
        "target_masks",
    )
    merger = Node(
        Function(
            input_names=[*list_inputs, "output_dir", "hemisphere"],
            output_names=list(_ROI_FIELDS),
            function=merge_rois_task,
        ),
        name=name,
    )
    merger.inputs.output_dir = str(out_dir / "rois_merged")
    merger.inputs.hemisphere = "both"
    return merger
def prepare_final_merger(out_dir: Path, name: str = "final_merger") -> Node:
    """Create a two-input ``Node`` running ``binary_merge_rois_task``.

    The node combines two ROI bundles, intended to be the atlas and SynthSeg
    bundles. Each bundle is supplied through five ``left_*`` or ``right_*``
    inputs (``seed``, ``waypoints_file``, ``stop``, ``avoid``, ``target``).

    ``output_dir`` defaults to ``<out_dir>/rois_final`` and ``hemisphere``
    to ``"both"``.
    """
    # Imported here rather than at module level, as in the original factory.
    from ..operations.merging import binary_merge_rois_task

    bundle_slots = ("seed", "waypoints_file", "stop", "avoid", "target")
    # Produces left_seed … left_target, then right_seed … right_target.
    side_inputs = [f"{side}_{slot}" for side in ("left", "right") for slot in bundle_slots]
    merger = Node(
        Function(
            input_names=side_inputs + ["output_dir", "hemisphere"],
            output_names=list(_ROI_FIELDS),
            function=binary_merge_rois_task,
        ),
        name=name,
    )
    merger.inputs.output_dir = str(out_dir / "rois_final")
    merger.inputs.hemisphere = "both"
    return merger