Source code for thesis.workflows.hcp.operations.extraction
"""ROI extraction from label maps.
NOTE: The main function ``extract_rois_task`` runs inside a Nipype Function node
(potentially in a separate process), so the thesis logger is not available at
runtime and every dependency is imported inside the function body.
``sys.stderr.write()`` is used for warning messages instead of the logger.
"""
[docs]
def extract_rois_task(
    roi_file: str,
    label_file: str,
    waypoint_labels: dict,
    output_dir: str,
    hemisphere: str = "both",
) -> tuple[str, str, str, str, str]:
    """Extract ROI masks from a label map and assemble workflow-ready outputs.

    This function is executed inside a Nipype Function node, so it must be
    self-contained: every import and helper is defined inside its body.

    Args:
        roi_file: Input label-map image path.
        label_file: Optional CSV or whitespace-delimited label mapping file.
            It is treated as CSV when its first non-blank line contains both
            ``Name`` and ``LabelValue``. Otherwise each line is parsed as
            ``<integer value> <label name>``. It is only read when at least
            one ROI references a label by name.
        waypoint_labels: ROI extraction specification keyed by ROI name. Each
            entry is read for ``region_kind`` (``seed``, ``waypoint``,
            ``stop``, ``avoid`` or ``target``; other kinds are skipped). It is
            also read for one of ``label_values``, ``label_names`` or
            ``label_name``, checked in that order. Entries may contain
            hemisphere-specific keys (``left_label_name`` etc.) that this task
            resolves at runtime via ``hemisphere``.
        output_dir: Output directory for generated masks and waypoint list files.
        hemisphere: ``"left"``, ``"right"``, or ``"both"`` (default). When
            ``waypoint_labels`` entries declare hemisphere-specific labels,
            this selects which side(s) to extract. Entries without
            hemisphere-specific fields pass through unchanged. ``"left"`` and
            ``"right"`` write their outputs into ``output_dir/<hemisphere>``.

    Returns:
        Tuple of ``(seed, waypoints_file, stop_mask, avoid_mask, target_mask)``.
        Each element is a path, or ``""`` when nothing of that kind was
        produced. Several seed masks are merged into ``seed_merged.nii.gz``.
        For stop/avoid/target only the first mask is returned, and a warning
        is written to stderr if more were produced.

    Raises:
        FileNotFoundError: If required FSL commands are unavailable or inputs are missing.
        RuntimeError: If an FSL ROI extraction command fails.
    """
    import csv
    import sys
    from pathlib import Path
    from typing import Dict, List

    from thesis.workflows.hcp.common import filter_waypoint_labels_by_hemisphere
    from thesis.workflows.hcp.operations._fsl import run_fsl_command

    def warn(message: str) -> None:
        """Write one warning line to stderr (the thesis logger is unavailable here)."""
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    def read_label_mapping(path: str) -> Dict[str, str]:
        """Parse ``path`` into a ``{label name: label value string}`` mapping."""
        parsed: Dict[str, str] = {}
        # ``utf-8-sig`` drops a leading byte-order mark, which spreadsheet-
        # exported CSVs often carry. Otherwise it would be glued to the first
        # column name ("\ufeffName") and no row would match.
        with open(path, "r", encoding="utf-8-sig") as handle:
            raw_text = handle.read()
        lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
        if not lines:
            return parsed
        first_line = lines[0]
        if "Name" in first_line and "LabelValue" in first_line:
            # Feed only the non-blank lines so the detected header line is the
            # row DictReader uses as fieldnames. A leading blank line would
            # otherwise produce empty fieldnames and silently drop every row.
            for row in csv.DictReader(lines):
                # Tolerate padded headers/fields such as "Name, LabelValue".
                # DictReader stores surplus fields under the ``None`` key and
                # fills missing ones with ``None``; both are ignored below.
                normalized = {(key or "").strip(): value for key, value in row.items()}
                row_name = normalized.get("Name")
                row_value = normalized.get("LabelValue")
                if isinstance(row_name, str) and isinstance(row_value, str):
                    parsed[row_name.strip()] = row_value.strip()
        else:
            # "<value> <name>" lines: the first token must be an integer and
            # the rest of the line is taken as the name. Other lines are ignored.
            for line in lines:
                parts = line.split(None, 1)
                if len(parts) != 2:
                    continue
                parsed_value, parsed_name = parts
                if parsed_value.lstrip("-").isdigit():
                    parsed[parsed_name.strip()] = parsed_value.strip()
        return parsed

    def extract_multi_label(values: list[int], output_mask: str) -> None:
        """Create a binary mask covering all listed label integer values.

        Single nibabel pass: loads the label map once, builds the union of
        the requested label values via ``np.isin``, and writes a binarized
        (0/1) mask. The mask is built with the source image's affine and
        header, and is cast back to the loaded array's dtype. This replaces
        the previous 2N-1 ``fslmaths`` subprocess chain.
        """
        import nibabel as nib
        import numpy as np

        img = nib.load(str(roi_file))
        data = np.asanyarray(img.dataobj)  # type: ignore[attr-defined]
        mask = np.isin(data, values).astype(data.dtype)
        out_img = nib.Nifti1Image(mask, img.affine, img.header)  # type: ignore[attr-defined]
        nib.save(out_img, output_mask)

    waypoint_labels = filter_waypoint_labels_by_hemisphere(waypoint_labels, hemisphere)

    # Per-hemisphere subdir so that parallel left/right iterations of the same
    # extractor MapNode do not race on identically-named output files
    # (e.g. rois_synthseg/hemisphere_avoid.nii.gz). "both" keeps the legacy
    # flat layout.
    out_path = Path(output_dir)
    if hemisphere in ("left", "right"):
        out_path = out_path / hemisphere
    out_path.mkdir(parents=True, exist_ok=True)

    # The label file is only needed when some ROI is specified by name.
    needs_csv = any(
        info.get("label_name") or info.get("label_names") for info in waypoint_labels.values()
    )
    mapping: Dict[str, str] = {}
    if needs_csv:
        if label_file:
            mapping = read_label_mapping(label_file)
        else:
            warn(
                "Warning: ROIs reference label names but no label_file was provided; "
                "ROIs resolved by label name will be skipped."
            )

    outputs: Dict[str, List[str]] = {
        "seed": [],
        "waypoint": [],
        "stop": [],
        "avoid": [],
        "target": [],
    }
    for name, info in waypoint_labels.items():
        kind = info.get("region_kind")
        if kind not in outputs:
            continue
        mask_path = str(out_path / f"{name}.nii.gz")
        label_values = info.get("label_values")
        label_name = info.get("label_name")
        label_names: list[str] | None = info.get("label_names")
        if label_values:
            extract_multi_label([int(value) for value in label_values], mask_path)
        elif label_names:
            # Multiple label names (e.g. left + right merged into one mask).
            # Unresolvable names are skipped individually; the ROI itself is
            # skipped only when none resolve.
            resolved_values: list[int] = []
            for lname in label_names:
                lval: str | None = mapping.get(lname)
                if not lval:
                    warn(
                        f"Warning: label '{lname}' was not found in {label_file}; "
                        f"skipping label for ROI '{name}'."
                    )
                    continue
                resolved_values.append(int(lval))
            if not resolved_values:
                warn(f"Warning: no valid labels resolved for ROI '{name}'; skipping.")
                continue
            extract_multi_label(resolved_values, mask_path)
        elif label_name:
            label_value: str | None = mapping.get(label_name)
            if not label_value:
                warn(
                    f"Warning: label '{label_name}' was not found in {label_file}; "
                    f"skipping ROI '{name}'."
                )
                continue
            # Keep only voxels equal to label_value, then binarize.
            run_fsl_command(
                [
                    "fslmaths",
                    str(roi_file),
                    "-thr",
                    str(label_value),
                    "-uthr",
                    str(label_value),
                    "-bin",
                    mask_path,
                ]
            )
        else:
            continue
        outputs[kind].append(mask_path)

    seed = ""
    if len(outputs["seed"]) > 1:
        # Union of all seed masks, accumulated in place into seed_merged.
        seed = str(out_path / "seed_merged.nii.gz")
        first_seed, *other_seeds = outputs["seed"]
        run_fsl_command(["fslmaths", first_seed, "-bin", seed])
        for seed_mask in other_seeds:
            run_fsl_command(["fslmaths", seed, "-add", seed_mask, "-bin", seed])
    elif outputs["seed"]:
        seed = outputs["seed"][0]

    waypoints_file = ""
    if outputs["waypoint"]:
        # One mask path per line.
        waypoints_file = str(out_path / "waypoints.txt")
        with open(waypoints_file, "w", encoding="utf-8") as handle:
            handle.writelines(waypoint + "\n" for waypoint in outputs["waypoint"])

    def first_or_empty(mask_kind: str) -> str:
        """Return the first mask of ``mask_kind`` (or ``""``), warning about extras."""
        masks = outputs[mask_kind]
        if len(masks) > 1:
            # The return contract holds a single path per kind, so extra masks
            # cannot be passed on; make that visible instead of dropping them silently.
            warn(
                f"Warning: {len(masks)} '{mask_kind}' masks were produced but only "
                f"'{masks[0]}' is returned; ignoring {', '.join(masks[1:])}."
            )
        return masks[0] if masks else ""

    stop_mask = first_or_empty("stop")
    avoid_mask = first_or_empty("avoid")
    target_mask = first_or_empty("target")
    return seed, waypoints_file, stop_mask, avoid_mask, target_mask