"""Label manipulation operations for SynthSeg segmentation masks.
This module provides pure Python implementations for converting SynthSeg
labels to tissue type maps and creating binary masks for tractography.
"""
from pathlib import Path
from typing import Dict, List
import nibabel as nib
import numpy as np
from thesis.core.exceptions import FileIOError
from thesis.core.logging import get_logger
# Public API of this module; the mask-building helper defined further down is
# intentionally not exported.
__all__ = [
    "convert_synthseg_to_gm_wm_csf",
    "create_tractography_masks",
]

# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
# SynthSeg label -> tissue class lookup used by convert_synthseg_to_gm_wm_csf.
# Tissue classes:
#   0 = background, 1 = GM (gray matter), 2 = WM (white matter), 3 = CSF
# Labels that do not appear in this table are left at 0 (background) by the
# conversion, except values > 1000, which that function maps to GM separately.
SYNTHSEG_TO_TISSUE_MAP = {
    0: 0,  # background
    2: 2,  # left cerebral white matter
    3: 1,  # left cerebral cortex
    4: 3,  # left lateral ventricle
    5: 3,  # left inferior lateral ventricle
    7: 2,  # left cerebellar white matter
    8: 1,  # left cerebellar cortex
    10: 1,  # left thalamus
    11: 1,  # left caudate
    12: 1,  # left putamen
    13: 1,  # left pallidum
    14: 3,  # 3rd ventricle
    15: 3,  # 4th ventricle
    16: 1,  # brain stem
    17: 1,  # left hippocampus
    18: 1,  # left amygdala
    24: 3,  # CSF
    26: 1,  # left accumbens area
    28: 1,  # left ventral DC
    41: 2,  # right cerebral white matter
    42: 1,  # right cerebral cortex
    43: 3,  # right lateral ventricle
    44: 3,  # right inferior lateral ventricle
    46: 2,  # right cerebellar white matter
    47: 1,  # right cerebellar cortex
    49: 1,  # right thalamus
    50: 1,  # right caudate
    51: 1,  # right putamen
    52: 1,  # right pallidum
    53: 1,  # right hippocampus
    54: 1,  # right amygdala
    58: 1,  # right accumbens area
    60: 1,  # right ventral DC
}
def _create_binary_mask_from_labels(label_image: np.ndarray, labels: List[int]) -> np.ndarray:
    """Create a binary mask marking every voxel whose label is in ``labels``.

    The input is coerced with ``np.asarray`` before the comparison, so
    array-likes such as nested lists are tested element-wise instead of being
    compared as a single object (which would silently yield an empty mask).

    Args:
        label_image: Input label image as numpy array (array-likes accepted).
        labels: Label values to include in the mask. An empty collection
            produces an all-zero mask.

    Returns:
        Binary mask as uint8 array (0 or 1) with the same shape as
        ``label_image``. The input is not modified.

    Example:
        >>> label_img = np.array([[[0, 1, 2], [3, 4, 5]]])
        >>> mask = _create_binary_mask_from_labels(label_img, [2, 4])
        >>> mask
        array([[[0, 0, 1],
                [0, 1, 0]]], dtype=uint8)
    """
    label_array = np.asarray(label_image)
    # list(...) keeps sets/tuples/generators working: np.isin would otherwise
    # treat a set as one opaque object rather than a collection of labels.
    wanted = list(labels)
    # One membership test replaces a separate full-volume comparison per label.
    return np.isin(label_array, wanted).astype(np.uint8)
def convert_synthseg_to_gm_wm_csf(input_image: str, output_image: str) -> str:
    """Convert SynthSeg labels to GM/WM/CSF tissue type map.

    Takes a SynthSeg segmentation and converts it to a simplified tissue type
    map using the following mapping:

    - 0 = background
    - 1 = gray matter (GM)
    - 2 = white matter (WM)
    - 3 = cerebrospinal fluid (CSF)

    Labels listed in ``SYNTHSEG_TO_TISSUE_MAP`` are translated through that
    table; labels missing from it stay 0 (background). All label values > 1000
    are mapped to 1 (GM) to handle FreeSurfer-style high-value labels.

    Voxel values are rounded to the nearest integer before the lookup, and the
    result is written with an on-disk data type of uint8.

    Args:
        input_image: Path to input SynthSeg segmentation (.nii or .nii.gz)
        output_image: Path to output tissue type map (.nii or .nii.gz).
            Missing parent directories are created.

    Returns:
        Absolute path to the output image

    Raises:
        FileNotFoundError: If input image doesn't exist
        FileIOError: If unable to read input or write output

    Example:
        >>> output_path = convert_synthseg_to_gm_wm_csf(
        ...     "synthseg.nii.gz",
        ...     "tissue_types.nii.gz"
        ... )
        >>> print(f"Created tissue map: {output_path}")
    """
    input_path = Path(input_image).resolve()
    output_path = Path(output_image).resolve()

    # Fail early with the builtin exception so callers can tell "missing
    # input" apart from a genuine read/write failure.
    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")

    logger.info(f"Converting SynthSeg labels to GM/WM/CSF: {input_path.name}")

    try:
        img = nib.load(str(input_path))
        # get_fdata() returns floating point values; round before casting so
        # a label read back as e.g. 2.9999 is not truncated to the wrong one.
        data = np.rint(img.get_fdata()).astype(np.int32)  # type: ignore[attr-defined]

        # Everything starts as background; only known labels are overwritten.
        output_data = np.zeros_like(data, dtype=np.uint8)
        for synthseg_label, tissue_type in SYNTHSEG_TO_TISSUE_MAP.items():
            output_data[data == synthseg_label] = tissue_type

        # Map all values > 1000 to GM (1).
        # This handles FreeSurfer-style high-value labels.
        output_data[data > 1000] = 1

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Reuse the input's affine/header so the map stays aligned with it.
        output_img = nib.Nifti1Image(
            output_data, img.affine, img.header  # type: ignore[attr-defined]
        )
        # When a header is supplied, nibabel keeps that header's data type for
        # writing; state uint8 explicitly so the file matches the array.
        output_img.set_data_dtype(np.uint8)
        nib.save(output_img, str(output_path))
    except FileNotFoundError:
        # Propagate unchanged, as documented in "Raises".
        raise
    except Exception as e:
        raise FileIOError(f"Failed to convert SynthSeg labels: {e}") from e

    logger.info(f"Created tissue type map: {output_path.name}")
    return str(output_path)
def create_tractography_masks(input_image: str, output_prefix: str) -> Dict[str, str]:
    """Create three binary masks for tractography from SynthSeg labels.

    Creates the following masks:

    1. Cerebellar right + cerebral left: labels [2, 46]
    2. Cerebellar left + cerebral right: labels [7, 41]
    3. CSF (exclusion mask): labels [4, 5, 14, 15, 24, 43, 44]

    Output files are named:

    - {output_prefix}_cerebellar_right_cerebral_left.nii.gz
    - {output_prefix}_cerebellar_left_cerebral_right.nii.gz
    - {output_prefix}_csf.nii.gz

    Voxel values are rounded to the nearest integer before matching labels,
    and every mask is written with an on-disk data type of uint8.

    Args:
        input_image: Path to input SynthSeg segmentation (.nii or .nii.gz)
        output_prefix: Path prefix for output masks (without extension).
            Missing parent directories are created.

    Returns:
        Dictionary mapping mask names to absolute output paths:

        - "cerebellar_right_cerebral_left": right cerebellar + left cerebral WM
        - "cerebellar_left_cerebral_right": left cerebellar + right cerebral WM
        - "csf": CSF exclusion mask

    Raises:
        FileNotFoundError: If input image doesn't exist
        FileIOError: If unable to read input or write outputs

    Example:
        >>> masks = create_tractography_masks(
        ...     "synthseg.nii.gz",
        ...     "derivatives/masks/sub-01"
        ... )
        >>> for name, path in masks.items():
        ...     print(f"{name}: {path}")
    """
    input_path = Path(input_image).resolve()
    output_prefix_path = Path(output_prefix).resolve()

    # Fail early with the builtin exception so callers can tell "missing
    # input" apart from a genuine read/write failure.
    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")

    logger.info(f"Creating tractography masks from: {input_path.name}")

    # Mask name -> SynthSeg labels that make up the mask.
    mask_definitions = {
        "cerebellar_right_cerebral_left": [2, 46],  # Left cerebral WM + Right cerebellar WM
        "cerebellar_left_cerebral_right": [7, 41],  # Left cerebellar WM + Right cerebral WM
        "csf": [4, 5, 14, 15, 24, 43, 44],  # All CSF compartments
    }

    try:
        img = nib.load(str(input_path))
        # get_fdata() returns floating point values; round before casting so
        # a label read back as e.g. 2.9999 is not truncated to the wrong one.
        data = np.rint(img.get_fdata()).astype(np.int32)  # type: ignore[attr-defined]

        output_prefix_path.parent.mkdir(parents=True, exist_ok=True)

        output_paths = {}
        for mask_name, labels in mask_definitions.items():
            mask_data = _create_binary_mask_from_labels(data, labels)

            # Append to the prefix's final component by string formatting so
            # dots inside the prefix are never mistaken for a file suffix.
            output_path = (
                output_prefix_path.parent / f"{output_prefix_path.name}_{mask_name}.nii.gz"
            )

            # Reuse the input's affine/header so masks stay aligned with it.
            mask_img = nib.Nifti1Image(
                mask_data, img.affine, img.header  # type: ignore[attr-defined]
            )
            # When a header is supplied, nibabel keeps that header's data type
            # for writing; state uint8 explicitly so the file matches the mask.
            mask_img.set_data_dtype(np.uint8)
            nib.save(mask_img, str(output_path))

            output_paths[mask_name] = str(output_path)
            logger.debug(
                f"Created {mask_name} mask with {np.sum(mask_data)} voxels: {output_path.name}"
            )
    except FileNotFoundError:
        # Propagate unchanged, as documented in "Raises".
        raise
    except Exception as e:
        raise FileIOError(f"Failed to create tractography masks: {e}") from e

    logger.info(f"Created {len(output_paths)} tractography masks")
    return output_paths