# Source code for thesis.workflows.preprocess.operations.label_ops

"""Label manipulation operations for SynthSeg segmentation masks.

This module provides pure Python implementations for converting SynthSeg
labels to tissue type maps and creating binary masks for tractography.
"""

from pathlib import Path
from typing import Dict, List

import nibabel as nib
import numpy as np

from thesis.core.exceptions import FileIOError
from thesis.core.logging import get_logger

# Public API of this module: the two file-level conversion operations.
__all__ = [
    "convert_synthseg_to_gm_wm_csf",
    "create_tractography_masks",
]

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)


# Tissue codes used in the simplified tissue type map.
_TISSUE_BACKGROUND = 0
_TISSUE_GM = 1  # gray matter
_TISSUE_WM = 2  # white matter
_TISSUE_CSF = 3  # cerebrospinal fluid

# SynthSeg label -> tissue code. Labels not listed here are left as background
# by convert_synthseg_to_gm_wm_csf, except values > 1000, which it maps to GM.
SYNTHSEG_TO_TISSUE_MAP = {
    0: _TISSUE_BACKGROUND,  # background
    2: _TISSUE_WM,  # left cerebral white matter
    3: _TISSUE_GM,  # left cerebral cortex
    4: _TISSUE_CSF,  # left lateral ventricle
    5: _TISSUE_CSF,  # left inferior lateral ventricle
    7: _TISSUE_WM,  # left cerebellar white matter
    8: _TISSUE_GM,  # left cerebellar cortex
    10: _TISSUE_GM,  # left thalamus
    11: _TISSUE_GM,  # left caudate
    12: _TISSUE_GM,  # left putamen
    13: _TISSUE_GM,  # left pallidum
    14: _TISSUE_CSF,  # 3rd ventricle
    15: _TISSUE_CSF,  # 4th ventricle
    16: _TISSUE_GM,  # brain stem
    17: _TISSUE_GM,  # left hippocampus
    18: _TISSUE_GM,  # left amygdala
    24: _TISSUE_CSF,  # CSF
    26: _TISSUE_GM,  # left accumbens area
    28: _TISSUE_GM,  # left ventral DC
    41: _TISSUE_WM,  # right cerebral white matter
    42: _TISSUE_GM,  # right cerebral cortex
    43: _TISSUE_CSF,  # right lateral ventricle
    44: _TISSUE_CSF,  # right inferior lateral ventricle
    46: _TISSUE_WM,  # right cerebellar white matter
    47: _TISSUE_GM,  # right cerebellar cortex
    49: _TISSUE_GM,  # right thalamus
    50: _TISSUE_GM,  # right caudate
    51: _TISSUE_GM,  # right putamen
    52: _TISSUE_GM,  # right pallidum
    53: _TISSUE_GM,  # right hippocampus
    54: _TISSUE_GM,  # right amygdala
    58: _TISSUE_GM,  # right accumbens area
    60: _TISSUE_GM,  # right ventral DC
}


def _create_binary_mask_from_labels(label_image: np.ndarray, labels: List[int]) -> np.ndarray:
    """Return a 0/1 ``uint8`` mask marking voxels whose value is one of *labels*.

    Args:
        label_image: Label volume to test, as a numpy array
        labels: Label values whose voxels become 1; every other voxel is 0

    Returns:
        Array with the same shape as ``label_image`` and dtype ``uint8``

    Example:
        >>> label_img = np.array([[[0, 1, 2], [3, 4, 5]]])
        >>> mask = _create_binary_mask_from_labels(label_img, [2, 4])
        >>> mask
        array([[[0, 0, 1],
                [0, 1, 0]]], dtype=uint8)
    """
    # One vectorised membership test instead of one comparison pass per label;
    # list() so any iterable of labels is treated element-wise by np.isin.
    is_selected = np.isin(label_image, list(labels))
    return is_selected.astype(np.uint8)


def convert_synthseg_to_gm_wm_csf(input_image: str, output_image: str) -> str:
    """Convert SynthSeg labels to GM/WM/CSF tissue type map.

    Takes a SynthSeg segmentation and converts it to a simplified tissue type
    map using the following mapping:

    - 0 = background
    - 1 = gray matter (GM)
    - 2 = white matter (WM)
    - 3 = cerebrospinal fluid (CSF)

    Individual labels are translated with ``SYNTHSEG_TO_TISSUE_MAP``. Labels
    not in that table are written as background (0). All label values > 1000
    are mapped to 1 (GM) to handle FreeSurfer-style high-value labels.

    Voxel values are rounded to the nearest integer before lookup, so labels
    stored as slightly-off floats (e.g. 2.9999) are still matched. The output
    reuses the input affine and header but is stored as ``uint8``.

    Args:
        input_image: Path to input SynthSeg segmentation (.nii or .nii.gz)
        output_image: Path to output tissue type map (.nii or .nii.gz)

    Returns:
        Absolute path to the output image

    Raises:
        FileNotFoundError: If input image doesn't exist
        FileIOError: If unable to read input or write output

    Example:
        >>> output_path = convert_synthseg_to_gm_wm_csf(  # doctest: +SKIP
        ...     "synthseg.nii.gz",
        ...     "tissue_types.nii.gz"
        ... )
        >>> print(f"Created tissue map: {output_path}")  # doctest: +SKIP
    """
    input_path = Path(input_image).resolve()
    output_path = Path(output_image).resolve()

    # Check up front so a missing input surfaces as FileNotFoundError rather
    # than being wrapped in FileIOError below.
    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")

    logger.info(f"Converting SynthSeg labels to GM/WM/CSF: {input_path.name}")

    try:
        img = nib.load(str(input_path))
        # get_fdata() yields floats; round before casting so near-integer
        # values are not truncated to the label below them.
        data = np.rint(img.get_fdata()).astype(np.int32)  # type: ignore[attr-defined]

        # Unmapped labels stay 0 (background).
        output_data = np.zeros_like(data, dtype=np.uint8)
        for synthseg_label, tissue_type in SYNTHSEG_TO_TISSUE_MAP.items():
            output_data[data == synthseg_label] = tissue_type

        # Map all values > 1000 to GM (1).
        # This handles FreeSurfer-style high-value labels.
        output_data[data > 1000] = 1

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save with the same header/affine as the input.
        output_img = nib.Nifti1Image(
            output_data, img.affine, img.header  # type: ignore[attr-defined]
        )
        # When a header is supplied, nibabel takes the on-disk dtype from it
        # rather than from the array; force uint8 to match the 0-3 codes.
        output_img.set_data_dtype(np.uint8)
        nib.save(output_img, str(output_path))
    except FileNotFoundError:
        raise
    except Exception as e:
        raise FileIOError(f"Failed to convert SynthSeg labels: {e}") from e

    logger.info(f"Created tissue type map: {output_path.name}")
    return str(output_path)
def create_tractography_masks(input_image: str, output_prefix: str) -> Dict[str, str]:
    """Create three binary masks for tractography from SynthSeg labels.

    Creates the following masks:

    1. Cerebellar right + cerebral left: labels [2, 46]
    2. Cerebellar left + cerebral right: labels [7, 41]
    3. CSF (exclusion mask): labels [4, 5, 14, 15, 24, 43, 44]

    Output files are named:

    - {output_prefix}_cerebellar_right_cerebral_left.nii.gz
    - {output_prefix}_cerebellar_left_cerebral_right.nii.gz
    - {output_prefix}_csf.nii.gz

    Voxel values are rounded to the nearest integer before matching labels.
    Each mask is 0/1, reuses the input affine and header, and is stored as
    ``uint8``.

    Args:
        input_image: Path to input SynthSeg segmentation (.nii or .nii.gz)
        output_prefix: Path prefix for output masks (without extension)

    Returns:
        Dictionary mapping mask names to absolute output paths:

        - "cerebellar_right_cerebral_left": right cerebellar + left cerebral WM
        - "cerebellar_left_cerebral_right": left cerebellar + right cerebral WM
        - "csf": CSF exclusion mask

    Raises:
        FileNotFoundError: If input image doesn't exist
        FileIOError: If unable to read input or write outputs

    Example:
        >>> masks = create_tractography_masks(  # doctest: +SKIP
        ...     "synthseg.nii.gz",
        ...     "derivatives/masks/sub-01"
        ... )
        >>> for name, path in masks.items():  # doctest: +SKIP
        ...     print(f"{name}: {path}")
    """
    input_path = Path(input_image).resolve()
    output_prefix_path = Path(output_prefix).resolve()

    # Check up front so a missing input surfaces as FileNotFoundError rather
    # than being wrapped in FileIOError below.
    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")

    logger.info(f"Creating tractography masks from: {input_path.name}")

    # Mask name -> SynthSeg labels included in that mask.
    mask_definitions = {
        "cerebellar_right_cerebral_left": [2, 46],  # Left cerebral WM + Right cerebellar WM
        "cerebellar_left_cerebral_right": [7, 41],  # Left cerebellar WM + Right cerebral WM
        "csf": [4, 5, 14, 15, 24, 43, 44],  # All CSF compartments
    }

    try:
        img = nib.load(str(input_path))
        # get_fdata() yields floats; round before casting so near-integer
        # values are not truncated to the label below them.
        data = np.rint(img.get_fdata()).astype(np.int32)  # type: ignore[attr-defined]

        output_prefix_path.parent.mkdir(parents=True, exist_ok=True)

        output_paths: Dict[str, str] = {}
        for mask_name, labels in mask_definitions.items():
            mask_data = _create_binary_mask_from_labels(data, labels)

            output_path = (
                output_prefix_path.parent / f"{output_prefix_path.name}_{mask_name}.nii.gz"
            )

            # Save with the same header/affine as the input.
            mask_img = nib.Nifti1Image(
                mask_data, img.affine, img.header  # type: ignore[attr-defined]
            )
            # When a header is supplied, nibabel takes the on-disk dtype from
            # it rather than from the array; force uint8 for the 0/1 mask.
            mask_img.set_data_dtype(np.uint8)
            nib.save(mask_img, str(output_path))

            output_paths[mask_name] = str(output_path)
            logger.debug(
                f"Created {mask_name} mask with {np.count_nonzero(mask_data)} voxels: {output_path.name}"
            )
    except FileNotFoundError:
        raise
    except Exception as e:
        raise FileIOError(f"Failed to create tractography masks: {e}") from e

    logger.info(f"Created {len(output_paths)} tractography masks")
    return output_paths