Source code for thesis.workflows.learned_atlas.train

"""Learned conditional deformable atlas - training entry + atlas emission.

:func:`train_learned_atlas` is called by the ``learned_atlas`` Nipype Function
node. It loads the cohort tract-density stack (reusing the averaging-atlas IO
and discovery so subject IDs and stack order stay aligned), trains a learnable
template ``T`` plus a deformation-only diffeomorphic network, then emits the
five atlas maps byte-compatibly with the averaging atlas.

Byte-compatibility: the output affine/header are taken from the first cohort
fdt image exactly as :func:`thesis.workflows.atlas.compute.generate_statistical_atlas`
does (NOT from a pre-existing atlas file), and ``mean`` is the learned template
while std/std_error/cov/prob_threshold come from
:func:`thesis.workflows.atlas._statistics.compute_atlas_statistics` over the
warped-template population - identical keys, dtypes (float32) and formulae.

Gotchas: all heavy imports are LOCAL (the body is pickled into a Nipype
subprocess and loguru cannot be pickled); diagnostics use
``print(..., file=sys.stderr)``; torch is the optional ``ml`` extra; VRAM is
released before returning.
"""

from __future__ import annotations

from typing import List, Optional


def train_learned_atlas(
    input_dir: str,
    atlas_dir: str,
    tractography_relpath: str,
    normalization_method: str,
    training_space: str,
    affine_native_relpath: Optional[str],
    emit_baseline_maps: list,
    presence_value: float,
    cov_mean_threshold_pct: float,
    save_fields: bool,
    verify_jacobian: bool,
    train_config: dict,
) -> List[str]:
    """Train the learned deformable atlas and emit atlas NIfTIs.

    The function discovers the cohort, builds the subject stack, trains the
    template + deformation model, runs one evaluation pass per subject to
    collect warped templates (and optionally displacement fields / Jacobian
    statistics), then writes the atlas maps and a training-metadata JSON.

    Args:
        input_dir: Cohort root containing numeric patient subdirectories.
        atlas_dir: Directory where atlas NIfTI files and ``fields/`` are saved.
        tractography_relpath: Relative path under each subject to the
            template-space tractography output.
        normalization_method: ``NormalizationMethod`` value string.
        training_space: ``"template_native"`` or ``"affine_native"``.
        affine_native_relpath: Per-subject affine-native map path; required
            when ``training_space='affine_native'``.
        emit_baseline_maps: Names from ATLAS_STATISTIC_NAMES to write. Only
            names present in this list are written; this function does not
            add ``'mean'`` itself, so callers must include it if they want it.
        presence_value: Threshold for the derived ``prob_threshold`` map.
        cov_mean_threshold_pct: Fraction of the mean max for the ``cov`` map.
        save_fields: Whether to write per-subject voxel-displacement fields.
        verify_jacobian: Whether to count non-positive Jacobian voxels per
            subject.
        train_config: Device/dtype/model/loss/optimizer scalars. Keys read
            here: ``seed``, ``device``, ``dtype``, ``n_templates``,
            ``int_steps``, ``enc_features``, ``dec_features``,
            ``similarity_weight``, ``presence_weight``, ``smoothness_weight``,
            ``log_offset``, ``lr``, ``epochs``, ``batch_size``.

    Returns:
        Absolute paths to generated files (atlas maps in
        ATLAS_STATISTIC_NAMES order, then per-subject fields, then the
        metadata JSON).

    Raises:
        DependencyError: If torch is not installed.
        ProcessingError: On empty cohort, unsupported ``dtype`` or
            ``training_space``, stack-build failure, or training failure.
    """
    # All imports are local on purpose (see the module docstring: the body is
    # shipped into a Nipype subprocess).
    import gc
    import json
    import sys
    import warnings
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import DependencyError, ProcessingError

    try:
        import torch
    except ImportError as exc:  # pragma: no cover - exercised only without [ml]
        raise DependencyError(
            "The 'learned_atlas' workflow requires PyTorch. Install with: "
            "pip install -e '.[ml]'"
        ) from exc

    from thesis.workflows.atlas._io import _build_patient_stack
    from thesis.workflows.atlas._params import (
        ATLAS_FILENAME_MAP,
        ATLAS_STATISTIC_NAMES,
        NormalizationMethod,
    )
    from thesis.workflows.atlas._statistics import compute_atlas_statistics
    from thesis.workflows.learned_atlas.losses import LearnedAtlasLoss
    from thesis.workflows.learned_atlas.model import LearnedAtlasModel
    from thesis.workflows.tract_similarity.hcp_loo import _discover_hcp_subjects

    def _log(message: str) -> None:
        # Diagnostics go to stderr via print (no logger objects in this body).
        print(f"[learned_atlas] {message}", file=sys.stderr, flush=True)

    cfg = dict(train_config)

    # Seed every RNG that the training loop touches.
    seed = int(cfg.get("seed", 0))
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    requested_device = str(cfg.get("device", "cuda"))
    if requested_device.startswith("cuda") and not torch.cuda.is_available():
        _log(f"CUDA unavailable; falling back to CPU (requested {requested_device!r}).")
        requested_device = "cpu"
    if requested_device == "cpu":
        _log("Training on CPU - expect this to be SLOW for ~145^3 volumes.")
    device = torch.device(requested_device)

    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    dtype_name = str(cfg.get("dtype", "float32"))
    if dtype_name not in dtype_map:
        # Surface a configuration error as the documented exception type
        # instead of a bare KeyError.
        raise ProcessingError(
            f"Unsupported dtype {dtype_name!r}; expected one of {sorted(dtype_map)}."
        )
    autocast_dtype = dtype_map[dtype_name]

    atlas_path = Path(atlas_dir)
    atlas_path.mkdir(parents=True, exist_ok=True)
    fields_dir = atlas_path / "fields"
    if save_fields:
        fields_dir.mkdir(parents=True, exist_ok=True)

    norm = NormalizationMethod(normalization_method)
    cohort_root = Path(input_dir)

    # Discovery via the LOO helper guarantees aligned (pid, runs) and sorted order.
    subjects = _discover_hcp_subjects(cohort_root, tractography_relpath)
    if len(subjects) == 0:
        raise ProcessingError(f"learned_atlas found no valid cohort subjects under {cohort_root}.")
    pids = [pid for pid, _ in subjects]
    patient_inputs = [runs for _, runs in subjects]
    _log(f"Discovered {len(pids)} cohort subjects (space={training_space}).")

    # Byte-compat reference: first cohort fdt image's affine/header, exactly like
    # atlas/compute.py.
    ref_file: Path = patient_inputs[0][0][0]
    ref_img = nib.load(str(ref_file))
    affine = np.asarray(ref_img.affine)  # type: ignore[attr-defined]
    # Work on a copy so the loaded image's header is not mutated, and fix the
    # on-disk dtype to float32 *before* anything is written: nibabel takes the
    # stored dtype from a supplied header, so the displacement fields saved
    # below would otherwise inherit the reference image's dtype.
    header = ref_img.header.copy()
    header.set_data_dtype(np.float32)  # type: ignore[attr-defined]
    _rshape = ref_img.shape  # type: ignore[attr-defined]
    ref_shape = (int(_rshape[0]), int(_rshape[1]), int(_rshape[2]))
    _log(f"Reference shape: {ref_shape}; normalization: {norm.value}")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        if training_space == "template_native":
            try:
                stack_np = _build_patient_stack(patient_inputs, ref_shape, norm)
            except ValueError as exc:
                raise ProcessingError(
                    f"learned_atlas stack build failed under {cohort_root}: {exc}"
                ) from exc
        elif training_space == "affine_native":
            if not affine_native_relpath:
                raise ProcessingError(
                    "training_space='affine_native' requires affine_native_relpath."
                )
            stack_np = _build_affine_native_stack(  # type: ignore[assignment]
                cohort_root, pids, affine_native_relpath, ref_shape, _log
            )
        else:
            raise ProcessingError(
                f"Unknown training_space {training_space!r}; expected "
                "'template_native' or 'affine_native'."
            )

    n_subjects = int(stack_np.shape[0])
    # The voxel-wise cohort mean initialises the learnable template.
    cohort_mean = stack_np.mean(axis=0).astype(np.float32)

    n_templates = int(cfg.get("n_templates", 1))
    if n_templates > 1:
        _log(
            f"n_templates={n_templates}: mixture-of-K is a Phase-2 stub; "
            "training a single template (K=1)."
        )
        n_templates = 1

    int_steps = int(cfg.get("int_steps", 7))

    # Add a channel axis: (N, X, Y, Z) -> (N, 1, X, Y, Z).
    subjects_t = torch.from_numpy(stack_np[:, None]).to(device, torch.float32)
    model = LearnedAtlasModel(
        volume_shape=ref_shape,
        n_templates=n_templates,
        int_steps=int_steps,
        enc_features=tuple(cfg.get("enc_features", (16, 32, 32, 32))),
        dec_features=tuple(cfg.get("dec_features", (32, 32, 32, 16, 16))),
        init_template=torch.from_numpy(cohort_mean),
    ).to(device)
    loss_fn = LearnedAtlasLoss(
        similarity_weight=float(cfg.get("similarity_weight", 1.0)),
        presence_weight=float(cfg.get("presence_weight", 0.5)),
        smoothness_weight=float(cfg.get("smoothness_weight", 0.5)),
        log_offset=float(cfg.get("log_offset", 1e-3)),
        presence_value=float(presence_value),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg.get("lr", 1e-4)))

    epochs = int(cfg.get("epochs", 50))
    batch = max(1, int(cfg.get("batch_size", 2)))
    # Autocast is only engaged on CUDA and only for a reduced-precision dtype.
    use_autocast = device.type == "cuda" and autocast_dtype != torch.float32

    model.train()
    history: list = []
    try:
        for epoch in range(epochs):
            perm = torch.randperm(n_subjects, device=device)
            epoch_total = 0.0
            for start in range(0, n_subjects, batch):
                idx = perm[start : start + batch]
                target = subjects_t[idx]
                optimizer.zero_grad(set_to_none=True)
                if use_autocast:
                    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                        out = model(target)
                    # Loss (incl. integration-sensitive NCC) computed in float32.
                    # NOTE(review): assumed to sit outside the autocast context,
                    # given the explicit .float() casts - confirm against the
                    # repository copy.
                    loss, comps = loss_fn(
                        out["warped_template"].float(), target, out["velocity"].float()
                    )
                else:
                    out = model(target)
                    loss, comps = loss_fn(out["warped_template"], target, out["velocity"])
                loss.backward()
                optimizer.step()
                # Weight by the actual batch size so a short final batch does
                # not skew the epoch mean.
                epoch_total += comps["total"] * target.shape[0]
            mean_total = epoch_total / n_subjects
            history.append({"epoch": epoch, "total": mean_total})
            # Log roughly ten times per run, plus always on the last epoch.
            if epoch % max(1, epochs // 10) == 0 or epoch == epochs - 1:
                _log(f"epoch {epoch + 1}/{epochs} total={mean_total:.6e}")
    except Exception as exc:  # noqa: BLE001
        raise ProcessingError(f"learned_atlas training loop failed: {exc}") from exc

    # Final pass: warped reconstructions, per-subject fields, Jacobian check.
    model.eval()
    warped_stack = np.empty((n_subjects, *ref_shape), dtype=np.float32)
    field_paths: list = []
    min_jacobian = float("inf")
    n_folded = 0
    with torch.no_grad():
        for i, pid in enumerate(pids):
            out = model(subjects_t[i : i + 1])
            warped_stack[i] = out["warped_template"].squeeze().to("cpu", torch.float32).numpy()
            if verify_jacobian:
                det = jacobian_min_and_folds(out["deformation"])
                min_jacobian = min(min_jacobian, det[0])
                n_folded += det[1]
            if save_fields:
                # Drop the batch axis and move the component axis last.
                field_np = (
                    out["deformation"]
                    .squeeze(0)
                    .permute(1, 2, 3, 0)
                    .to("cpu", torch.float32)
                    .numpy()
                )
                field_file = fields_dir / f"{pid}_voxel_displacement.nii.gz"
                nib.save(
                    nib.Nifti1Image(field_np.astype(np.float32), affine, header), str(field_file)
                )
                field_paths.append(str(field_file))
    if verify_jacobian:
        _log(
            f"Diffeomorphism check (voxel-space): min Jacobian det={min_jacobian:.4e}, "
            f"folded voxels={n_folded} (expected 0)."
        )

    # Emit maps: 'mean' = learned template; dispersion maps from warped stack via
    # the SAME compute_atlas_statistics used by the averaging atlas.
    template_np = model.export_template()[0].numpy().astype(np.float32)  # (X, Y, Z), K=1
    derived = compute_atlas_statistics(
        warped_stack,
        presence_value=float(presence_value),
        cov_mean_threshold_pct=float(cov_mean_threshold_pct),
    )
    derived["mean"] = template_np  # the sharp learned template is the deliverable mean

    generated: List[str] = []
    for name in ATLAS_STATISTIC_NAMES:
        if name not in emit_baseline_maps:
            continue
        out_file = atlas_path / ATLAS_FILENAME_MAP[name]
        nib.save(nib.Nifti1Image(derived[name].astype(np.float32), affine, header), str(out_file))
        generated.append(str(out_file))
        _log(f"Saved {out_file}")

    meta = {
        "n_subjects": n_subjects,
        "training_space": training_space,
        "normalization_method": norm.value,
        "device": requested_device,
        "dtype": dtype_name,
        "epochs": epochs,
        "batch_size": batch,
        "int_steps": int_steps,
        # Stays None when the Jacobian check was skipped (sentinel never updated).
        "min_jacobian_determinant_voxelspace": (
            min_jacobian if min_jacobian != float("inf") else None
        ),
        "n_neg_jac_voxelspace": n_folded,
        "final_total_loss": history[-1]["total"] if history else None,
        "history": history,
        "subject_ids": pids,
        "field_convention": "voxel_displacement (X,Y,Z,3); NOT an ANTs/SITK warp",
    }
    meta_file = atlas_path / "learned_atlas_training.json"
    meta_file.write_text(json.dumps(meta, indent=2), encoding="utf-8")
    generated.extend(field_paths)
    generated.append(str(meta_file))

    # Release VRAM before returning.
    del model, optimizer, subjects_t, loss_fn
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return generated
def jacobian_min_and_folds(deformation: "object") -> tuple:
    """Summarise the voxel-space Jacobian determinant of a deformation field.

    The determinant itself is computed by the model module's
    ``jacobian_determinant``; it is imported locally so this wrapper carries
    no module-level dependency on the model package.

    Args:
        deformation: Torch tensor of voxel displacements, documented by the
            original author as shape ``(1, 3, X, Y, Z)``.

    Returns:
        ``(min_det, n_negative)`` as ``(float, int)``: the smallest
        determinant and the number of voxels whose determinant is ``<= 0``.
    """
    from thesis.workflows.learned_atlas.model import jacobian_determinant

    determinants = jacobian_determinant(deformation)  # type: ignore[arg-type]
    smallest = determinants.min().item()
    # Zero counts as folded too: the comparison is "<= 0", not "< 0".
    non_positive = (determinants <= 0).sum().item()
    return float(smallest), int(non_positive)
def _build_affine_native_stack(
    cohort_root: "object",
    pids: list,
    relpath: str,
    ref_shape: tuple,
    log: "object",
) -> "object":
    """Stack affine-only-aligned native maps (avoids the double-registration confound).

    Reads ``<cohort_root>/<pid>/<relpath>`` per subject; ``relpath`` may be a
    per-tract/bundle path and degrades to the whole-map case. For a 4-D image
    only the first volume along the last axis is kept.

    Args:
        cohort_root: Cohort directory (anything ``str()`` turns into a path).
        pids: Subject identifiers, in the order the stack should follow.
        relpath: Path of the map relative to each subject directory.
        ref_shape: Expected spatial shape ``(X, Y, Z)`` of every map.
        log: Callable taking one message string.

    Returns:
        A float32 array with one entry per subject stacked along axis 0.

    Raises:
        ProcessingError: If ``pids`` is empty, a map is missing, or its shape
            mismatches ``ref_shape``.
    """
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import ProcessingError

    root = Path(str(cohort_root))
    if len(pids) == 0:
        # np.stack would otherwise fail with an unrelated ValueError.
        raise ProcessingError(f"No subjects supplied for the affine-native stack under {root}.")

    # Normalise so a list-typed shape still compares equal to ndarray.shape.
    expected_shape = tuple(ref_shape)

    volumes = []
    for pid in pids:
        # str() keeps the path join valid for non-string subject identifiers.
        map_file = root / str(pid) / relpath
        if not map_file.is_file():
            raise ProcessingError(f"Affine-native map not found for {pid}: {map_file}")
        _obj = nib.load(str(map_file)).dataobj  # type: ignore[attr-defined]
        arr = np.asarray(_obj, dtype=np.float32)
        if arr.shape[:3] != expected_shape:
            raise ProcessingError(
                f"Affine-native map {map_file} shape {arr.shape[:3]} != {ref_shape}"
            )
        volumes.append(arr[..., 0] if arr.ndim == 4 else arr)

    log(f"Loaded {len(volumes)} affine-native maps from '{relpath}'.")  # type: ignore[operator]
    # Every volume is already float32, so copy=False makes this cast a no-op
    # instead of duplicating the whole cohort stack in memory.
    return np.stack(volumes, axis=0).astype(np.float32, copy=False)