"""Atlas generation — orchestration layer.

This module provides the entry points for generating statistical atlases
from cohort tractography data, using NumPy for computation.

Key features:
- The full cohort is loaded via nibabel and stacked into a single
  (n_subjects, X, Y, Z) array.
- Vectorised statistics computation via NumPy (mean, std, std_error, cov,
  prob_threshold).
- An optional single-pass mode (``generate_atlas_with_qc``) derives the QC
  outputs from the same in-memory stack as the atlas statistics.
"""
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, cast
import nibabel as nib
from nibabel.nifti1 import Nifti1Header
from thesis.workflows.atlas._io import (
_build_patient_stack,
_collect_patient_inputs,
)
from thesis.workflows.atlas._params import (
ATLAS_FILENAME_MAP,
ATLAS_STATISTIC_NAMES,
AtlasParams,
)
from thesis.workflows.atlas._qc import _prepare_qc_dir, _run_qc_from_stack
from thesis.workflows.atlas._statistics import compute_atlas_statistics
[docs]
def generate_statistical_atlas(
    output_dir: str,
    atlas_dir: str,
    params: AtlasParams,
    tractography_relpath: str = "tractography/probtrackx2",
) -> List[str]:
    """Generate a statistical reference atlas from cohort tractography data.

    Discovers all patient warped fdt_paths.nii.gz files, normalises them using
    the configured method (waytotal, max, or softmax), stacks them into a
    (n_subjects, X, Y, Z) array, and computes the atlas statistics (mean,
    std, std_error, cov, and prob_threshold) using vectorised NumPy operations.
    One NIfTI file is written per entry of ``ATLAS_STATISTIC_NAMES``; every
    output reuses the affine and header of the first discovered patient volume.

    Args:
        output_dir: Base output directory containing patient subdirectories.
        atlas_dir: Directory where atlas NIfTI files will be saved (created if
            missing).
        params: Runtime parameters including normalization method and thresholds.
        tractography_relpath: Relative path under each patient directory
            where tractography output lives.

    Returns:
        List of paths (as strings) to the generated NIfTI files, in
        ``ATLAS_STATISTIC_NAMES`` order. Each path is ``atlas_dir`` joined with
        the mapped filename, so it is absolute only if ``atlas_dir`` is.

    Raises:
        ValueError: If no valid patient data is found.
    """
    out_path = Path(output_dir)
    atlas_path = Path(atlas_dir)
    atlas_path.mkdir(parents=True, exist_ok=True)

    # Discover patient data
    patient_inputs = _collect_patient_inputs(out_path, tractography_relpath)
    print(f"Found {len(patient_inputs)} patients for atlas generation.", file=sys.stderr)
    # Surface the documented ValueError rather than an IndexError from the
    # reference lookup below when discovery returns nothing.
    if not patient_inputs:
        raise ValueError(
            f"No valid patient data found under {out_path} "
            f"(tractography_relpath={tractography_relpath!r})"
        )

    # Get reference image metadata from the first patient.
    # NOTE(review): assumes patient_inputs[0][0] is a non-empty sequence of
    # volume paths for the first patient — structure is defined by
    # _collect_patient_inputs; confirm there.
    ref_file: Path = patient_inputs[0][0][0]
    ref_img = nib.load(ref_file)
    affine = ref_img.affine  # type: ignore[attr-defined]
    header = ref_img.header
    ref_shape: tuple[int, int, int] = ref_img.shape[:3]  # type: ignore[attr-defined]
    print(f"Reference shape: {ref_shape}", file=sys.stderr)
    print(f"Normalization method: {params.normalization_method.value}", file=sys.stderr)

    # Load all volumes into a (n_subjects, X, Y, Z) stack and compute statistics.
    # RuntimeWarnings (e.g. from divisions on empty voxels) are silenced only
    # for the duration of stacking and statistics.
    print("Loading patient volumes and computing atlas statistics...", file=sys.stderr)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        stack = _build_patient_stack(patient_inputs, ref_shape, params.normalization_method)
        computed_results = compute_atlas_statistics(
            stack,
            presence_value=params.presence_value,
            cov_mean_threshold_pct=params.cov_mean_threshold_pct,
        )

    # Save results, one file per statistic, sharing the reference geometry.
    generated: List[str] = []
    for name in ATLAS_STATISTIC_NAMES:
        out_file = atlas_path / ATLAS_FILENAME_MAP[name]
        nib.save(nib.Nifti1Image(computed_results[name], affine, header), out_file)
        generated.append(str(out_file))
        print(f"  Saved {out_file}", file=sys.stderr)
    return generated
[docs]
@dataclass(frozen=True, eq=True)
class QcParams:
    """Immutable bundle of settings for the atlas QC pass.

    Each field maps one-to-one onto a QC keyword argument that
    :func:`generate_atlas_with_qc` forwards to ``_run_qc_from_stack``.
    The exception is ``qc_subdir``, which is handed to ``_prepare_qc_dir``.
    The single-pass atlas+QC computation can therefore pass them through
    unchanged.

    Attributes:
        subject_density_threshold: Intensity threshold used to binarize maps.
        group_core_threshold: Occupancy threshold [0, 1] for the group core mask.
        leave_one_out: Whether to compute leave-one-out Dice summaries.
        leave_one_out_min_subjects: Minimum subjects required for leave-one-out.
        compute_cv: Whether to compute a voxelwise coefficient of variation map.
        cv_mean_threshold: Mean threshold applied when computing the CV map.
        compute_log_stats: Whether to compute log-space summary metrics.
        log_offset: Positive offset added before log-transform summaries.
        qc_subdir: Name of the QC subdirectory, created beneath the QC base
            directory (``qc_output_dir`` in :func:`generate_atlas_with_qc`).
    """

    # Binarisation / group-core thresholds.
    subject_density_threshold: float
    group_core_threshold: float  # occupancy fraction in [0, 1]

    # Leave-one-out Dice summary controls.
    leave_one_out: bool
    leave_one_out_min_subjects: int

    # Optional voxelwise coefficient-of-variation map.
    compute_cv: bool
    cv_mean_threshold: float

    # Optional log-space summary metrics.
    compute_log_stats: bool
    log_offset: float  # positive offset added before the log transform

    # Output location (relative name, resolved by _prepare_qc_dir).
    qc_subdir: str = "atlas_qc"
[docs]
def generate_atlas_with_qc(
    output_dir: str,
    atlas_dir: str,
    params: AtlasParams,
    qc_params: Optional[QcParams] = None,
    qc_output_dir: Optional[str] = None,
    tractography_relpath: str = "tractography/probtrackx2",
) -> List[str]:
    """Generate the statistical atlas and (optionally) QC in a single pass.

    The cohort is discovered, normalised, and stacked into a single
    ``(n_subjects, X, Y, Z)`` array exactly once. Both the statistical-atlas
    outputs and the QC outputs are derived from that one in-memory stack,
    avoiding the redundant collect/stack/mean-std that the separate atlas and
    QC nodes previously performed. Only one stack is held in RAM at a time.

    Every output file (atlas statistics, occupancy, core mask, optional CV map,
    and the QC JSON summary) is written to the same paths with the same numeric
    values as the previous two-node implementation.

    Args:
        output_dir: Base output directory containing patient subdirectories
            (the cohort discovery root).
        atlas_dir: Directory where atlas NIfTI files will be saved (created if
            missing).
        params: Atlas runtime parameters (normalization, thresholds).
        qc_params: Optional QC runtime parameters. When ``None``, QC is skipped
            and only the statistical atlas is produced.
        qc_output_dir: Base directory under which the QC subdirectory is created.
            Required when ``qc_params`` is provided.
        tractography_relpath: Relative path under each patient directory
            where tractography output lives.

    Returns:
        List of paths (as strings) to all generated NIfTI/JSON files: atlas
        outputs first, in ``ATLAS_STATISTIC_NAMES`` order, then the paths
        returned by the QC pass when QC is enabled. Atlas paths are
        ``atlas_dir`` joined with the mapped filename, so they are absolute
        only if ``atlas_dir`` is.

    Raises:
        ValueError: If no valid patient data is found, or if ``qc_params`` is
            supplied without ``qc_output_dir``.
    """
    # Validate arguments before any filesystem side effect, so a bad call does
    # not leave a freshly created atlas directory behind.
    if qc_params is not None and qc_output_dir is None:
        raise ValueError("qc_output_dir must be provided when qc_params is set")

    out_path = Path(output_dir)
    atlas_path = Path(atlas_dir)
    atlas_path.mkdir(parents=True, exist_ok=True)

    # Discover patient data (single pass)
    patient_inputs = _collect_patient_inputs(out_path, tractography_relpath)
    print(f"Found {len(patient_inputs)} patients for atlas generation.", file=sys.stderr)
    # Surface the documented ValueError rather than an IndexError from the
    # reference lookup below when discovery returns nothing.
    if not patient_inputs:
        raise ValueError(
            f"No valid patient data found under {out_path} "
            f"(tractography_relpath={tractography_relpath!r})"
        )

    # Reference image metadata from the first patient.
    # NOTE(review): assumes patient_inputs[0][0] is a non-empty sequence of
    # volume paths for the first patient — structure is defined by
    # _collect_patient_inputs; confirm there.
    ref_file: Path = patient_inputs[0][0][0]
    ref_img = nib.load(ref_file)
    affine = ref_img.affine  # type: ignore[attr-defined]
    header = ref_img.header
    ref_shape: tuple[int, int, int] = ref_img.shape[:3]  # type: ignore[attr-defined]
    print(f"Reference shape: {ref_shape}", file=sys.stderr)
    print(f"Normalization method: {params.normalization_method.value}", file=sys.stderr)

    # Load the cohort stack ONCE and compute atlas statistics from it.
    # RuntimeWarnings are silenced only for stacking + atlas statistics; the
    # QC pass below runs outside this filter.
    print("Loading patient volumes and computing atlas statistics...", file=sys.stderr)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        stack = _build_patient_stack(patient_inputs, ref_shape, params.normalization_method)
        computed_results = compute_atlas_statistics(
            stack,
            presence_value=params.presence_value,
            cov_mean_threshold_pct=params.cov_mean_threshold_pct,
        )

    # Save atlas results, one file per statistic, sharing the reference geometry.
    generated: List[str] = []
    for name in ATLAS_STATISTIC_NAMES:
        out_file = atlas_path / ATLAS_FILENAME_MAP[name]
        nib.save(nib.Nifti1Image(computed_results[name], affine, header), out_file)
        generated.append(str(out_file))
        print(f"  Saved {out_file}", file=sys.stderr)

    # Derive QC from the SAME in-memory stack.
    if qc_params is not None:
        # Guaranteed by the validation at the top of the function; the assert
        # only narrows the Optional for static type checkers.
        assert qc_output_dir is not None
        qc_dir = _prepare_qc_dir(qc_output_dir, qc_params.qc_subdir)
        generated.extend(
            _run_qc_from_stack(
                subject_value_maps=stack,
                affine=affine,
                header=cast(Nifti1Header, header),
                qc_dir=qc_dir,
                subject_density_threshold=qc_params.subject_density_threshold,
                group_core_threshold=qc_params.group_core_threshold,
                leave_one_out=qc_params.leave_one_out,
                leave_one_out_min_subjects=qc_params.leave_one_out_min_subjects,
                compute_cv=qc_params.compute_cv,
                cv_mean_threshold=qc_params.cv_mean_threshold,
                compute_log_stats=qc_params.compute_log_stats,
                log_offset=qc_params.log_offset,
            )
        )
    return generated