Source code for thesis.workflows.preprocess.nodes.registration

"""Node builders for ANTs registration workflows.

This module provides node builders for ANTs-based image registration including:
- ANTs Registration (Rigid, Affine, SyN)
- ANTs ApplyTransforms (forward and inverse)
"""

from nipype import Node
from nipype.interfaces.ants import ApplyTransforms, Registration

from thesis.core.logging import get_logger

# Public API of this module: the three builder functions defined below.
__all__ = [
    "prepare_ants_registration_node",
    "make_ants_apply_transforms_interface",
    "prepare_ants_apply_transforms_node",
]

# Module-level logger, named after this module; used by the builders for
# debug-level tracing of node creation.
logger = get_logger(__name__)


def prepare_ants_registration_node(
    transform_type: str = "Rigid",
    num_threads: int = 1,
    name: str = "registration",
) -> Node:
    """Build a Nipype node wrapping ANTs ``Registration`` for a transform type.

    The requested ``transform_type`` selects a cumulative pipeline of stages:
    "Rigid" runs a single rigid stage, "Affine" runs Rigid followed by Affine,
    and "SyN" runs Rigid, Affine and then SyN.

    Args:
        transform_type: Type of registration transform.
            Options: "Rigid", "Affine", "SyN"
        num_threads: Number of parallel threads for ANTs (default: 1)
        name: Node name (default: "registration")

    Returns:
        Configured Nipype Node with ANTs Registration interface

    Raises:
        ValueError: If transform_type is not one of "Rigid", "Affine", "SyN"

    Example:
        >>> node = prepare_ants_registration_node(transform_type="Rigid", num_threads=4)
        >>> node.inputs.moving_image = "T1.nii.gz"
        >>> node.inputs.fixed_image = "DWI_b0.nii.gz"
    """
    # Ordered stage pipeline for every supported transform type. The key order
    # is also the order in which the options are listed in the error message.
    pipelines: dict[str, list[str]] = {
        "Rigid": ["Rigid"],
        "Affine": ["Rigid", "Affine"],
        "SyN": ["Rigid", "Affine", "SyN"],
    }
    pipeline = pipelines.get(transform_type)
    if pipeline is None:
        raise ValueError(
            f"Invalid transform_type: {transform_type}. Must be one of {list(pipelines)}"
        )

    logger.debug(f"Creating ANTs Registration node: {name} (type: {transform_type})")

    # (transform parameters, iterations per resolution level) for each stage.
    # The linear stages take a single gradient step; SyN takes
    # (gradient step, flow sigma, total sigma) and runs a shorter schedule
    # because each iteration is more expensive.
    linear_step = (0.1,)
    per_stage: dict[str, tuple[tuple[float, ...], list[int]]] = {
        "Rigid": (linear_step, [1000, 500, 250, 100]),
        "Affine": (linear_step, [1000, 500, 250, 100]),
        "SyN": ((0.25, 3.0, 0.0), [100, 70, 50, 20]),
    }
    n_stages = len(pipeline)

    # Inputs are listed in the order in which they are assigned to the node.
    settings = {
        # --- settings shared by every registration type ---
        "dimension": 3,
        "winsorize_lower_quantile": 0.005,
        "winsorize_upper_quantile": 0.995,
        "use_histogram_matching": False,  # Important for T1->DWI registration
        "write_composite_transform": True,
        "verbose": True,
        "num_threads": num_threads,
        # 1 = center-of-mass init (antsRegistration -r [fixed,moving,1]).
        # 0 means identity init, which leaves Rigid+MI stuck in the wrong
        # basin when T1 and DWI aren't pre-aligned and silently produces a
        # near-identity transform.
        "initial_moving_transform_com": 1,
        # --- per-stage settings (one entry per stage) ---
        "transforms": list(pipeline),
        "transform_parameters": [per_stage[stage][0] for stage in pipeline],
        "number_of_iterations": [per_stage[stage][1] for stage in pipeline],
        # Metric (MI), shrink factors, smoothing sigmas, weight, bins,
        # sampling and convergence are identical across every stage.
        "metric": ["MI"] * n_stages,
        "shrink_factors": [[8, 4, 2, 1] for _ in range(n_stages)],
        "smoothing_sigmas": [[3, 2, 1, 0] for _ in range(n_stages)],
        "metric_weight": [1.0] * n_stages,
        "radius_or_number_of_bins": [32] * n_stages,
        "sampling_strategy": ["Regular"] * n_stages,
        "sampling_percentage": [0.25] * n_stages,
        "convergence_threshold": [1e-6] * n_stages,
        "convergence_window_size": [10] * n_stages,
    }

    reg = Node(Registration(), name=name)
    for input_name, value in settings.items():
        setattr(reg.inputs, input_name, value)

    logger.debug(f"ANTs Registration node created: {name}")
    return reg
def make_ants_apply_transforms_interface(
    dimension: int = 3,
    interpolation: str = "BSpline",
) -> ApplyTransforms:
    """Return a pre-configured ANTs ``ApplyTransforms`` interface.

    The interface is returned bare (not wrapped in a Node), so a caller can
    hand it straight to a ``Node`` or a ``MapNode`` instead of building a
    throwaway Node and pulling out its ``.interface``.

    Args:
        dimension: Image dimension (2 or 3, default: 3)
        interpolation: Interpolation method.
            Options: "BSpline", "Linear", "NearestNeighbor", "MultiLabel"
            (default: "BSpline")

    Returns:
        Configured ANTs ApplyTransforms interface

    Raises:
        ValueError: If dimension is not 2 or 3
        ValueError: If interpolation is not valid
    """
    # (label used in the error message, supplied value, accepted values).
    # Checked in this order, so an invalid dimension is reported first.
    checks = (
        ("dimension", dimension, [2, 3]),
        (
            "interpolation",
            interpolation,
            ["BSpline", "Linear", "NearestNeighbor", "MultiLabel"],
        ),
    )
    for label, value, allowed in checks:
        if value not in allowed:
            raise ValueError(f"Invalid {label}: {value}. Must be one of {allowed}")

    interface = ApplyTransforms()
    spec = interface.inputs
    spec.dimension = dimension
    spec.interpolation = interpolation
    spec.float = True  # Use float precision
    return interface
def prepare_ants_apply_transforms_node(
    dimension: int = 3,
    interpolation: str = "BSpline",
    name: str = "apply_transforms",
) -> Node:
    """Wrap a configured ANTs ``ApplyTransforms`` interface in a Nipype Node.

    Argument validation and interface configuration are delegated to
    ``make_ants_apply_transforms_interface``; this function only adds the
    Node wrapper and debug logging.

    Args:
        dimension: Image dimension (2 or 3, default: 3)
        interpolation: Interpolation method.
            Options: "BSpline", "Linear", "NearestNeighbor", "MultiLabel"
            (default: "BSpline")
        name: Node name (default: "apply_transforms")

    Returns:
        Configured Nipype Node with ANTs ApplyTransforms interface

    Raises:
        ValueError: If dimension is not 2 or 3
        ValueError: If interpolation is not valid

    Example:
        >>> node = prepare_ants_apply_transforms_node(interpolation="Linear")
        >>> node.inputs.input_image = "T1.nii.gz"
        >>> node.inputs.reference_image = "DWI_b0.nii.gz"
        >>> node.inputs.transforms = ["0GenericAffine.mat"]
    """
    logger.debug(f"Creating ANTs ApplyTransforms node: {name} (interpolation: {interpolation})")

    # Build (and validate) the interface first, then wrap it.
    configured_interface = make_ants_apply_transforms_interface(
        dimension=dimension,
        interpolation=interpolation,
    )
    node = Node(configured_interface, name=name)

    logger.debug(f"ANTs ApplyTransforms node created: {name}")
    return node