"""Node builders for ANTs registration workflows.
This module provides node builders for ANTs-based image registration including:
- ANTs Registration (Rigid, Affine, SyN)
- ANTs ApplyTransforms (forward and inverse)
"""
from nipype import Node
from nipype.interfaces.ants import ApplyTransforms, Registration
from thesis.core.logging import get_logger
# Public API of this module.
__all__ = [
    "prepare_ants_registration_node",
    "make_ants_apply_transforms_interface",
    "prepare_ants_apply_transforms_node",
]

# Module-level logger for node-construction diagnostics.
logger = get_logger(__name__)
def prepare_ants_registration_node(
    transform_type: str = "Rigid",
    num_threads: int = 1,
    name: str = "registration",
) -> Node:
    """Create an ANTs Registration node configured for the given transform type.

    Every registration parameter is set on the returned node, but the images
    are not: callers assign ``fixed_image`` and ``moving_image`` themselves
    (see Example).

    Args:
        transform_type: Type of registration transform. Options: "Rigid",
            "Affine", "SyN". "Affine" runs a Rigid stage first; "SyN" runs
            Rigid and Affine stages before the SyN stage.
        num_threads: Number of parallel threads for ANTs (default: 1).
        name: Node name (default: "registration").

    Returns:
        Configured Nipype Node wrapping the ANTs Registration interface.

    Raises:
        ValueError: If transform_type is not one of "Rigid", "Affine", "SyN".

    Example:
        >>> node = prepare_ants_registration_node(transform_type="Rigid", num_threads=4)
        >>> node.inputs.moving_image = "T1.nii.gz"
        >>> node.inputs.fixed_image = "DWI_b0.nii.gz"
    """
    # Each transform_type expands to an ordered list of registration stages.
    # More flexible transforms are warmed up by the simpler linear stages
    # (SyN is preceded by Rigid + Affine).
    stages_by_type: dict[str, list[str]] = {
        "Rigid": ["Rigid"],
        "Affine": ["Rigid", "Affine"],
        "SyN": ["Rigid", "Affine", "SyN"],
    }
    if transform_type not in stages_by_type:
        raise ValueError(
            f"Invalid transform_type: {transform_type}. Must be one of {list(stages_by_type)}"
        )

    logger.debug(f"Creating ANTs Registration node: {name} (type: {transform_type})")
    stages = stages_by_type[transform_type]

    # Per-stage (transform_parameters, number_of_iterations).
    # - Linear stages (Rigid/Affine) take a single gradient step and share one
    #   iteration schedule.
    # - The SyN stage takes (gradient step, update-field variance, total-field
    #   variance), i.e. antsRegistration's
    #   SyN[gradientStep, updateFieldVarianceInVoxelSpace, totalFieldVarianceInVoxelSpace],
    #   and uses a shorter schedule because each SyN iteration is more expensive.
    # Every schedule has four entries, one per resolution level, matching the
    # four-level shrink_factors / smoothing_sigmas set below.
    stage_params: dict[str, tuple[tuple[float, ...], list[int]]] = {
        "Rigid": ((0.1,), [1000, 500, 250, 100]),
        "Affine": ((0.1,), [1000, 500, 250, 100]),
        "SyN": ((0.25, 3.0, 0.0), [100, 70, 50, 20]),
    }

    reg = Node(Registration(), name=name)

    # --- Settings shared by every registration type -----------------------
    reg.inputs.dimension = 3
    # Winsorize intensities to the [0.5%, 99.5%] quantiles to damp outliers.
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    # Histogram matching is intended for same-modality pairs; it stays off
    # here because T1 and DWI intensity profiles differ.
    reg.inputs.use_histogram_matching = False
    # Write the chained stages as a single composite transform file.
    reg.inputs.write_composite_transform = True
    reg.inputs.verbose = True
    reg.inputs.num_threads = num_threads
    # Initialise by aligning centres of mass of the image intensities
    # (antsRegistration --initial-moving-transform [fixed, moving, 1]).
    # For this option 0 aligns the geometric image centres and 2 aligns the
    # image origins; leaving it unset starts from the identity transform.
    # Without a centre-of-mass start, Rigid+MI can get stuck in the wrong
    # basin when T1 and DWI aren't pre-aligned and silently produce a
    # near-identity transform.
    reg.inputs.initial_moving_transform_com = 1

    # --- Per-stage settings (one list entry per stage) ---------------------
    # Only transform parameters and iteration schedules vary by stage; metric
    # (MI), shrink factors, smoothing sigmas, weight, bins, sampling and
    # convergence criteria are identical across every stage. Comprehensions
    # (rather than `[x] * n`) give each stage its own inner list object.
    reg.inputs.transforms = list(stages)
    reg.inputs.transform_parameters = [stage_params[s][0] for s in stages]
    reg.inputs.number_of_iterations = [stage_params[s][1] for s in stages]
    reg.inputs.metric = ["MI" for _ in stages]
    reg.inputs.shrink_factors = [[8, 4, 2, 1] for _ in stages]
    reg.inputs.smoothing_sigmas = [[3, 2, 1, 0] for _ in stages]
    reg.inputs.metric_weight = [1.0 for _ in stages]
    reg.inputs.radius_or_number_of_bins = [32 for _ in stages]
    reg.inputs.sampling_strategy = ["Regular" for _ in stages]
    reg.inputs.sampling_percentage = [0.25 for _ in stages]
    reg.inputs.convergence_threshold = [1e-6 for _ in stages]
    reg.inputs.convergence_window_size = [10 for _ in stages]

    logger.debug(f"ANTs Registration node created: {name}")
    return reg