# Source code for thesis.workflows.preprocess.nodes.structural

"""Node builders for structural preprocessing workflows.

This module provides node builders for structural MRI preprocessing including:
- N4 bias field correction (ANTs)
- Robust field-of-view cropping (FSL)
- Image thresholding (FSL)
- SynthSeg segmentation (FreeSurfer)
- Tissue type conversion (GM/WM/CSF)
- Tractography mask creation
"""

from nipype import Node
from nipype.interfaces.ants import N4BiasFieldCorrection
from nipype.interfaces.fsl import RobustFOV, Threshold
from nipype.interfaces.utility import Function

from thesis.core.logging import get_logger
from thesis.core.nipype.interfaces.freesurfer import SynthSeg

# Public API of this module: the node-builder functions defined below.
# The underscore-prefixed wrapper functions are intentionally not exported.
__all__ = [
    "prepare_n4_bias_correction_node",
    "prepare_robust_fov_node",
    "prepare_threshold_node",
    "prepare_synthseg_node",
    "prepare_gm_wm_csf_converter_node",
    "prepare_tract_masks_creator_node",
]

# Module-level logger obtained from the project's logging helper, keyed by
# this module's name; the builders below use it for debug messages.
logger = get_logger(__name__)


def _convert_gm_wm_csf_wrapper(input_image: str, output_image: str) -> str:
    """Delegate to ``convert_synthseg_to_gm_wm_csf`` for use in a Function node.

    This function is passed as the ``function`` argument of the Nipype
    ``Function`` interface built in ``prepare_gm_wm_csf_converter_node``.
    The conversion routine is imported inside the body rather than at module
    level (presumably because Nipype runs Function code from its source text
    in an isolated namespace, so module-level imports would not be visible).

    Args:
        input_image: Path to input SynthSeg segmentation
        output_image: Path to output tissue type map

    Returns:
        Whatever the wrapped routine returns, unchanged; expected to be the
        absolute path to the output image (contract of the wrapped routine,
        not verified here).

    Raises:
        FileNotFoundError: If input image doesn't exist (propagated from the
            wrapped routine, per its documented contract)
        FileIOError: If unable to read input or write output (likewise
            propagated)
    """
    from thesis.workflows.preprocess.operations.label_ops import (
        convert_synthseg_to_gm_wm_csf as synthseg_to_tissue_types,
    )

    tissue_map_path = synthseg_to_tissue_types(input_image, output_image)
    return tissue_map_path


def _create_tract_masks_wrapper(input_image: str, output_prefix: str) -> dict:
    """Delegate to ``create_tractography_masks`` for use in a Function node.

    This function is passed as the ``function`` argument of the Nipype
    ``Function`` interface built in ``prepare_tract_masks_creator_node``.
    The mask-creation routine is imported inside the body rather than at
    module level (presumably because Nipype runs Function code from its
    source text in an isolated namespace).

    Args:
        input_image: Path to input SynthSeg segmentation
        output_prefix: Path prefix for output masks (without extension)

    Returns:
        Whatever the wrapped routine returns, unchanged; expected to be a
        dictionary mapping mask names to absolute output paths (contract of
        the wrapped routine, not verified here).

    Raises:
        FileNotFoundError: If input image doesn't exist (propagated from the
            wrapped routine, per its documented contract)
        FileIOError: If unable to read input or write outputs (likewise
            propagated)
    """
    from thesis.workflows.preprocess.operations.label_ops import (
        create_tractography_masks as build_tract_masks,
    )

    mask_paths = build_tract_masks(input_image, output_prefix)
    return mask_paths


def prepare_n4_bias_correction_node(
    dimension: int = 3,
    bspline_fitting_distance: int = 300,
    shrink_factor: int = 3,
    n_iterations: list[int] | None = None,
    name: str = "n4_bias_correction",
) -> Node:
    """Prepare Nipype node for ANTs N4 bias field correction.

    N4 is a variant of the popular N3 (nonparametric nonuniform normalization)
    retrospective bias correction algorithm. It improves on N3 by using a
    B-spline approximation to smooth the bias field.

    Args:
        dimension: Image dimension (2 or 3)
        bspline_fitting_distance: B-spline fitting distance in mm
        shrink_factor: Shrink factor for speed (higher = faster but less
            accurate)
        n_iterations: Number of iterations at each resolution level.
            Defaults to [50, 50, 30, 20].
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping ANTs N4BiasFieldCorrection

    Example:
        >>> node = prepare_n4_bias_correction_node(
        ...     dimension=3,
        ...     bspline_fitting_distance=300,
        ...     shrink_factor=3,
        ...     n_iterations=[50, 50, 30, 20]
        ... )
        >>> node.inputs.input_image = "T1.nii.gz"
        >>> node.inputs.output_image = "T1_corrected.nii.gz"
    """
    # None is the sentinel for "use the default schedule"; building the list
    # here gives every call its own list instead of a shared mutable default.
    if n_iterations is None:
        n_iterations = [50, 50, 30, 20]

    logger.debug(
        f"Creating N4 bias correction node: dimension={dimension}, "
        f"bspline_fitting_distance={bspline_fitting_distance}, "
        f"shrink_factor={shrink_factor}, n_iterations={n_iterations}"
    )

    node = Node(N4BiasFieldCorrection(), name=name)
    node.inputs.dimension = dimension
    node.inputs.bspline_fitting_distance = bspline_fitting_distance
    node.inputs.shrink_factor = shrink_factor
    node.inputs.n_iterations = n_iterations

    return node
def prepare_robust_fov_node(brainsize: int = 150, name: str = "robust_fov") -> Node:
    """Prepare Nipype node for FSL RobustFOV to crop field of view.

    RobustFOV reduces the field of view to remove lower head and neck regions,
    which can improve registration and reduce computation time.

    Args:
        brainsize: Brain size parameter in mm (typical range: 100-200)
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping FSL RobustFOV

    Example:
        >>> node = prepare_robust_fov_node(brainsize=150)
        >>> node.inputs.in_file = "T1.nii.gz"
        >>> node.inputs.out_roi = "T1_cropped.nii.gz"
    """
    logger.debug(f"Creating RobustFOV node: brainsize={brainsize}")

    node = Node(RobustFOV(), name=name)
    # Only the brain size is fixed here; input/output files are left for the
    # caller (or workflow connections) to set, as shown in the example above.
    node.inputs.brainsize = brainsize

    return node
def prepare_threshold_node(
    thresh: float, direction: str = "below", name: str = "threshold"
) -> Node:
    """Prepare Nipype node for FSL image thresholding.

    Args:
        thresh: Threshold value
        direction: Thresholding direction - "below" zeros values below thresh,
            "above" zeros values above thresh
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping FSL Threshold

    Raises:
        ValueError: If direction is not "below" or "above"

    Example:
        >>> node = prepare_threshold_node(thresh=0.5, direction="below")
        >>> node.inputs.in_file = "image.nii.gz"
        >>> node.inputs.out_file = "image_thresholded.nii.gz"
    """
    # Validate before any node is constructed so a bad direction fails fast
    # with a message naming the accepted values.
    if direction not in ("below", "above"):
        raise ValueError(f"direction must be 'below' or 'above', got: {direction}")

    logger.debug(f"Creating threshold node: thresh={thresh}, direction={direction}")

    node = Node(Threshold(), name=name)
    node.inputs.thresh = thresh
    node.inputs.direction = direction

    return node
def prepare_synthseg_node(
    parc: bool = True,
    robust: bool = True,
    fast: bool = True,
    name: str = "synthseg",
) -> Node:
    """Prepare Nipype node for FreeSurfer SynthSeg segmentation.

    SynthSeg is a learning-based tool for brain segmentation and parcellation
    that is robust to different MRI contrasts, resolutions, and acquisition
    artifacts.

    Args:
        parc: If True, output parcellation (191 labels); if False,
            segmentation (32 labels)
        robust: If True, use robust version for images with lesions and low
            contrast
        fast: If True, use fast approximate mode
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping the custom SynthSeg interface

    Example:
        >>> node = prepare_synthseg_node(parc=True, robust=True, fast=True)
        >>> node.inputs.input_image = "T1.nii.gz"
        >>> node.inputs.output_segmentation = "synthseg.nii.gz"
        >>> node.inputs.output_volumes = "volumes.csv"
        >>> node.inputs.output_qc = "qc.csv"
        >>> node.inputs.output_resampled = "resampled.nii.gz"
    """
    logger.debug(f"Creating SynthSeg node: parc={parc}, robust={robust}, fast={fast}")

    # SynthSeg here is the project's own interface
    # (thesis.core.nipype.interfaces.freesurfer), not one shipped with Nipype.
    node = Node(SynthSeg(), name=name)

    # Set default parameters
    node.inputs.parc = parc
    node.inputs.robust = robust
    node.inputs.fast = fast

    return node
def prepare_gm_wm_csf_converter_node(name: str = "convert_gm_wm_csf") -> Node:
    """Prepare Nipype node to convert SynthSeg labels to GM/WM/CSF tissue types.

    This node converts a SynthSeg segmentation to a simplified tissue type map:
    - 0 = background
    - 1 = gray matter (GM)
    - 2 = white matter (WM)
    - 3 = cerebrospinal fluid (CSF)

    Args:
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping convert_synthseg_to_gm_wm_csf via
        Function interface

    Example:
        >>> node = prepare_gm_wm_csf_converter_node()
        >>> node.inputs.input_image = "synthseg.nii.gz"
        >>> node.inputs.output_image = "tissue_types.nii.gz"
    """
    logger.debug("Creating GM/WM/CSF converter node")

    # "output_image" is both an input (the destination path) and the single
    # output (the value the wrapper returns), so downstream nodes can connect
    # to the produced file by the same name.
    node = Node(
        Function(
            input_names=["input_image", "output_image"],
            output_names=["output_image"],
            function=_convert_gm_wm_csf_wrapper,
        ),
        name=name,
    )

    return node
def prepare_tract_masks_creator_node(name: str = "create_tract_masks") -> Node:
    """Prepare Nipype node to create tractography masks from SynthSeg labels.

    This node creates three binary masks for tractography:
    1. Cerebellar right + cerebral left white matter
    2. Cerebellar left + cerebral right white matter
    3. CSF (exclusion mask)

    Args:
        name: Name for the Nipype node

    Returns:
        Configured Nipype Node wrapping create_tractography_masks via
        Function interface

    Example:
        >>> node = prepare_tract_masks_creator_node()
        >>> node.inputs.input_image = "synthseg.nii.gz"
        >>> node.inputs.output_prefix = "derivatives/masks/sub-01"
    """
    logger.debug("Creating tractography masks creator node")

    # The wrapper's return value (a dict, per its annotation) is exposed as
    # the node's single output, "output_paths".
    node = Node(
        Function(
            input_names=["input_image", "output_prefix"],
            output_names=["output_paths"],
            function=_create_tract_masks_wrapper,
        ),
        name=name,
    )

    return node