"""Node builders for structural preprocessing workflows.
This module provides node builders for structural MRI preprocessing including:
- N4 bias field correction (ANTs)
- Robust field-of-view cropping (FSL)
- Image thresholding (FSL)
- SynthSeg segmentation (FreeSurfer)
- Tissue type conversion (GM/WM/CSF)
- Tractography mask creation
"""
from nipype import Node
from nipype.interfaces.ants import N4BiasFieldCorrection
from nipype.interfaces.fsl import RobustFOV, Threshold
from nipype.interfaces.utility import Function
from thesis.core.logging import get_logger
from thesis.core.nipype.interfaces.freesurfer import SynthSeg
# Public API of this module: the node builder functions. The private
# `_..._wrapper` helpers are intentionally not exported.
__all__ = [
    "prepare_n4_bias_correction_node",
    "prepare_robust_fov_node",
    "prepare_threshold_node",
    "prepare_synthseg_node",
    "prepare_gm_wm_csf_converter_node",
    "prepare_tract_masks_creator_node",
]
# Module-level logger from the project's logging helper; each builder uses it
# to emit a debug message describing the node it creates.
logger = get_logger(__name__)
def _convert_gm_wm_csf_wrapper(input_image: str, output_image: str) -> str:
    """Run ``convert_synthseg_to_gm_wm_csf`` from inside a Nipype Function node.

    Nipype's Function interface rebuilds this function from its source text,
    so the project import is performed inside the body rather than relying on
    module-level imports.

    Args:
        input_image: Path to the input SynthSeg segmentation.
        output_image: Destination path for the tissue type map.

    Returns:
        Absolute path to the output image, as returned by
        ``convert_synthseg_to_gm_wm_csf``.

    Raises:
        FileNotFoundError: If the input image doesn't exist.
        FileIOError: If the input cannot be read or the output cannot be written.
    """
    from thesis.workflows.preprocess.operations import label_ops

    tissue_map_path = label_ops.convert_synthseg_to_gm_wm_csf(input_image, output_image)
    return tissue_map_path
def _create_tract_masks_wrapper(input_image: str, output_prefix: str) -> dict:
    """Run ``create_tractography_masks`` from inside a Nipype Function node.

    Nipype's Function interface rebuilds this function from its source text,
    so the project import is performed inside the body rather than relying on
    module-level imports.

    Args:
        input_image: Path to the input SynthSeg segmentation.
        output_prefix: Path prefix for the output masks (without extension).

    Returns:
        Dictionary mapping mask names to absolute output paths, as returned
        by ``create_tractography_masks``.

    Raises:
        FileNotFoundError: If the input image doesn't exist.
        FileIOError: If the input cannot be read or the outputs cannot be written.
    """
    from thesis.workflows.preprocess.operations import label_ops

    mask_paths = label_ops.create_tractography_masks(input_image, output_prefix)
    return mask_paths
[docs]
def prepare_n4_bias_correction_node(
    dimension: int = 3,
    bspline_fitting_distance: int = 300,
    shrink_factor: int = 3,
    n_iterations: list[int] | None = None,
    name: str = "n4_bias_correction",
) -> Node:
    """Build a Nipype node running ANTs N4 bias field correction.

    N4 is a variant of the popular N3 (nonparametric nonuniform normalization)
    retrospective bias correction algorithm. It improves on N3 by using a
    B-spline approximation to smooth the bias field.

    Only the algorithm parameters are set here. The input and output image
    paths are left for the caller or the enclosing workflow to supply.

    Args:
        dimension: Image dimension (2 or 3).
        bspline_fitting_distance: B-spline fitting distance in mm.
        shrink_factor: Shrink factor for speed (higher = faster but less
            accurate).
        n_iterations: Number of iterations at each resolution level. ``None``
            selects ``[50, 50, 30, 20]``.
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` wrapping ANTs ``N4BiasFieldCorrection`` with the
        parameters above applied.

    Example:
        >>> node = prepare_n4_bias_correction_node(
        ...     dimension=3,
        ...     bspline_fitting_distance=300,
        ...     shrink_factor=3,
        ...     n_iterations=[50, 50, 30, 20],
        ... )
        >>> node.inputs.input_image = "T1.nii.gz"
        >>> node.inputs.output_image = "T1_corrected.nii.gz"
    """
    # A fresh default list per call (never a shared mutable default).
    iterations = [50, 50, 30, 20] if n_iterations is None else n_iterations

    logger.debug(
        f"Creating N4 bias correction node: dimension={dimension}, "
        f"bspline_fitting_distance={bspline_fitting_distance}, "
        f"shrink_factor={shrink_factor}, n_iterations={iterations}"
    )

    n4_node = Node(N4BiasFieldCorrection(), name=name)

    # Interface input name -> value, applied in this order.
    settings = {
        "dimension": dimension,
        "bspline_fitting_distance": bspline_fitting_distance,
        "shrink_factor": shrink_factor,
        "n_iterations": iterations,
    }
    for input_name, input_value in settings.items():
        setattr(n4_node.inputs, input_name, input_value)

    return n4_node
[docs]
def prepare_robust_fov_node(brainsize: int = 150, name: str = "robust_fov") -> Node:
    """Build a Nipype node running FSL RobustFOV to crop the field of view.

    RobustFOV reduces the field of view to remove lower head and neck regions,
    which can improve registration and reduce computation time.

    Only ``brainsize`` is set here. The input file and output ROI paths are
    left for the caller or the enclosing workflow to supply.

    Args:
        brainsize: Brain size parameter in mm (typical range: 100-200).
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` wrapping FSL ``RobustFOV``.

    Example:
        >>> node = prepare_robust_fov_node(brainsize=150)
        >>> node.inputs.in_file = "T1.nii.gz"
        >>> node.inputs.out_roi = "T1_cropped.nii.gz"
    """
    logger.debug(f"Creating RobustFOV node: brainsize={brainsize}")

    fov_node = Node(RobustFOV(), name=name)
    fov_node.inputs.brainsize = brainsize

    return fov_node
[docs]
def prepare_threshold_node(
    thresh: float, direction: str = "below", name: str = "threshold"
) -> Node:
    """Build a Nipype node running FSL image thresholding.

    ``direction`` is validated before any node is created, so an invalid
    value fails fast without constructing the interface.

    Args:
        thresh: Threshold value.
        direction: Thresholding direction. ``"below"`` zeros values below
            ``thresh``; ``"above"`` zeros values above ``thresh``.
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` wrapping FSL ``Threshold`` with ``thresh`` and
        ``direction`` applied.

    Raises:
        ValueError: If ``direction`` is neither ``"below"`` nor ``"above"``.

    Example:
        >>> node = prepare_threshold_node(thresh=0.5, direction="below")
        >>> node.inputs.in_file = "image.nii.gz"
        >>> node.inputs.out_file = "image_thresholded.nii.gz"
    """
    allowed_directions = ("below", "above")
    if direction not in allowed_directions:
        raise ValueError(f"direction must be 'below' or 'above', got: {direction}")

    logger.debug(f"Creating threshold node: thresh={thresh}, direction={direction}")

    threshold_node = Node(Threshold(), name=name)
    threshold_node.inputs.thresh = thresh
    threshold_node.inputs.direction = direction

    return threshold_node
[docs]
def prepare_synthseg_node(
    parc: bool = True,
    robust: bool = True,
    fast: bool = True,
    name: str = "synthseg",
) -> Node:
    """Build a Nipype node running FreeSurfer SynthSeg segmentation.

    SynthSeg is a learning-based tool for brain segmentation and parcellation
    that is robust to different MRI contrasts, resolutions, and acquisition
    artifacts.

    Only the mode flags are set here. The input image and output paths are
    left for the caller or the enclosing workflow to supply.

    Args:
        parc: If True, output parcellation (191 labels); if False,
            segmentation (32 labels).
        robust: If True, use the robust version for images with lesions and
            low contrast.
        fast: If True, use fast approximate mode.
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` wrapping the project's custom ``SynthSeg`` interface.

    Example:
        >>> node = prepare_synthseg_node(parc=True, robust=True, fast=True)
        >>> node.inputs.input_image = "T1.nii.gz"
        >>> node.inputs.output_segmentation = "synthseg.nii.gz"
        >>> node.inputs.output_volumes = "volumes.csv"
        >>> node.inputs.output_qc = "qc.csv"
        >>> node.inputs.output_resampled = "resampled.nii.gz"
    """
    logger.debug(f"Creating SynthSeg node: parc={parc}, robust={robust}, fast={fast}")

    synthseg_node = Node(SynthSeg(), name=name)

    # Apply the caller-selected mode flags in a fixed order: parc, robust, fast.
    mode_flags = (("parc", parc), ("robust", robust), ("fast", fast))
    for flag_name, flag_value in mode_flags:
        setattr(synthseg_node.inputs, flag_name, flag_value)

    return synthseg_node
[docs]
def prepare_gm_wm_csf_converter_node(name: str = "convert_gm_wm_csf") -> Node:
    """Build a Nipype node converting SynthSeg labels to GM/WM/CSF tissue types.

    The node produces a simplified tissue type map from a SynthSeg
    segmentation:

    - 0 = background
    - 1 = gray matter (GM)
    - 2 = white matter (WM)
    - 3 = cerebrospinal fluid (CSF)

    The node exposes the inputs ``input_image`` and ``output_image`` and a
    single output, also named ``output_image``.

    Args:
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` running ``_convert_gm_wm_csf_wrapper`` (and therefore
        ``convert_synthseg_to_gm_wm_csf``) through a Function interface.

    Example:
        >>> node = prepare_gm_wm_csf_converter_node()
        >>> node.inputs.input_image = "synthseg.nii.gz"
        >>> node.inputs.output_image = "tissue_types.nii.gz"
    """
    logger.debug("Creating GM/WM/CSF converter node")

    converter_interface = Function(
        input_names=["input_image", "output_image"],
        output_names=["output_image"],
        function=_convert_gm_wm_csf_wrapper,
    )
    return Node(converter_interface, name=name)
[docs]
def prepare_tract_masks_creator_node(name: str = "create_tract_masks") -> Node:
    """Build a Nipype node creating tractography masks from SynthSeg labels.

    The node creates three binary masks for tractography:

    1. Cerebellar right + cerebral left white matter
    2. Cerebellar left + cerebral right white matter
    3. CSF (exclusion mask)

    The node exposes the inputs ``input_image`` and ``output_prefix`` and a
    single output, ``output_paths``.

    Args:
        name: Name for the Nipype node.

    Returns:
        Nipype ``Node`` running ``_create_tract_masks_wrapper`` (and therefore
        ``create_tractography_masks``) through a Function interface.

    Example:
        >>> node = prepare_tract_masks_creator_node()
        >>> node.inputs.input_image = "synthseg.nii.gz"
        >>> node.inputs.output_prefix = "derivatives/masks/sub-01"
    """
    logger.debug("Creating tractography masks creator node")

    masks_interface = Function(
        input_names=["input_image", "output_prefix"],
        output_names=["output_paths"],
        function=_create_tract_masks_wrapper,
    )
    return Node(masks_interface, name=name)