"""Learned conditional deformable atlas - training entry + atlas emission.
:func:`train_learned_atlas` is called by the ``learned_atlas`` Nipype Function
node. It loads the cohort tract-density stack (reusing the averaging-atlas IO
and discovery so subject IDs and stack order stay aligned), trains a learnable
template ``T`` plus a deformation-only diffeomorphic network, then emits the
five atlas maps in a form byte-compatible with the averaging atlas.
Byte-compatibility: the output affine/header are taken from the first cohort
fdt image exactly as :func:`thesis.workflows.atlas.compute.generate_statistical_atlas`
does (NOT from a pre-existing atlas file), and ``mean`` is the learned template
while std/std_error/cov/prob_threshold come from
:func:`thesis.workflows.atlas._statistics.compute_atlas_statistics` over the
warped-template population - identical keys, dtypes (float32) and formulae.
Gotchas: all heavy imports are LOCAL (the body is pickled into a Nipype
subprocess and loguru cannot be pickled); diagnostics use
``print(..., file=sys.stderr)``; torch is the optional ``ml`` extra; VRAM is
released before returning.
"""
from __future__ import annotations
from typing import List, Optional
[docs]
def train_learned_atlas(
    input_dir: str,
    atlas_dir: str,
    tractography_relpath: str,
    normalization_method: str,
    training_space: str,
    affine_native_relpath: Optional[str],
    emit_baseline_maps: list,
    presence_value: float,
    cov_mean_threshold_pct: float,
    save_fields: bool,
    verify_jacobian: bool,
    train_config: dict,
) -> List[str]:
    """Train the learned deformable atlas and emit atlas NIfTIs.

    Args:
        input_dir: Cohort root containing numeric patient subdirectories.
        atlas_dir: Directory where atlas NIfTI files and ``fields/`` are saved.
            Created if missing; made absolute so the returned paths are absolute.
        tractography_relpath: Relative path under each subject to the
            template-space tractography output.
        normalization_method: ``NormalizationMethod`` value string.
        training_space: ``"template_native"`` or ``"affine_native"``.
        affine_native_relpath: Per-subject affine-native map path; required when
            ``training_space='affine_native'``.
        emit_baseline_maps: Subset of ATLAS_STATISTIC_NAMES to write. ``'mean'``
            (the learned template) is always written, even if omitted here.
        presence_value: Threshold for the derived ``prob_threshold`` map; also
            forwarded to the loss.
        cov_mean_threshold_pct: Fraction of the mean max for the ``cov`` map.
        save_fields: Whether to write per-subject voxel-displacement fields.
        verify_jacobian: Whether to count non-positive Jacobian voxels per subject.
        train_config: Device/dtype/model/loss/optimizer scalars. Keys read (with
            defaults): ``seed`` (0), ``device`` ("cuda"), ``dtype`` ("float32";
            one of float16/float32/bfloat16), ``n_templates`` (1; only K=1 is
            trained), ``int_steps`` (7), ``enc_features`` ((16, 32, 32, 32)),
            ``dec_features`` ((32, 32, 32, 16, 16)), ``similarity_weight`` (1.0),
            ``presence_weight`` (0.5), ``smoothness_weight`` (0.5),
            ``log_offset`` (1e-3), ``lr`` (1e-4), ``epochs`` (50),
            ``batch_size`` (2, clamped to >= 1).

    Returns:
        Absolute paths to generated files (atlas maps in ATLAS_STATISTIC_NAMES
        order, then per-subject fields, then the metadata JSON).

    Raises:
        DependencyError: If torch is not installed.
        ProcessingError: On empty cohort, unknown ``training_space`` or ``dtype``,
            shape mismatch, non-finite loss, or training failure.
    """
    import gc
    import json
    import sys
    import warnings
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import DependencyError, ProcessingError

    try:
        import torch
    except ImportError as exc:  # pragma: no cover - exercised only without [ml]
        raise DependencyError(
            "The 'learned_atlas' workflow requires PyTorch. Install with: " "pip install -e '.[ml]'"
        ) from exc

    from thesis.workflows.atlas._io import _build_patient_stack
    from thesis.workflows.atlas._params import (
        ATLAS_FILENAME_MAP,
        ATLAS_STATISTIC_NAMES,
        NormalizationMethod,
    )
    from thesis.workflows.atlas._statistics import compute_atlas_statistics
    from thesis.workflows.learned_atlas.losses import LearnedAtlasLoss
    from thesis.workflows.learned_atlas.model import LearnedAtlasModel
    from thesis.workflows.tract_similarity.hcp_loo import _discover_hcp_subjects

    # NOTE(review): `_build_affine_native_stack` and `jacobian_min_and_folds` are
    # module-level names of this file, not local imports. If the Nipype node
    # re-executes only this function's source, they must be supplied through
    # the node's ``imports`` - confirm against the node definition.

    def _log(message: str) -> None:
        # stderr + flush so messages surface promptly from the Nipype subprocess.
        print(f"[learned_atlas] {message}", file=sys.stderr, flush=True)

    cfg = dict(train_config)

    # --- Reproducibility --------------------------------------------------
    seed = int(cfg.get("seed", 0))
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # --- Device / precision ----------------------------------------------
    requested_device = str(cfg.get("device", "cuda"))
    if requested_device.startswith("cuda") and not torch.cuda.is_available():
        _log(f"CUDA unavailable; falling back to CPU (requested {requested_device!r}).")
        requested_device = "cpu"
    if requested_device == "cpu":
        _log("Training on CPU - expect this to be SLOW for ~145^3 volumes.")
    device = torch.device(requested_device)

    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    dtype_name = str(cfg.get("dtype", "float32"))
    if dtype_name not in dtype_map:
        # Fail fast with the documented exception rather than a bare KeyError.
        raise ProcessingError(
            f"Unknown dtype {dtype_name!r}; expected one of {sorted(dtype_map)}."
        )
    autocast_dtype = dtype_map[dtype_name]

    # --- Output locations -------------------------------------------------
    # absolute() so returned paths are absolute as documented, even when
    # atlas_dir is given relative to the node's working directory.
    atlas_path = Path(atlas_dir).absolute()
    atlas_path.mkdir(parents=True, exist_ok=True)
    fields_dir = atlas_path / "fields"
    if save_fields:
        fields_dir.mkdir(parents=True, exist_ok=True)

    # --- Cohort discovery -------------------------------------------------
    norm = NormalizationMethod(normalization_method)
    cohort_root = Path(input_dir)
    # Discovery via the LOO helper guarantees aligned (pid, runs) and sorted order.
    subjects = _discover_hcp_subjects(cohort_root, tractography_relpath)
    if not subjects:
        raise ProcessingError(f"learned_atlas found no valid cohort subjects under {cohort_root}.")
    pids = [pid for pid, _ in subjects]
    patient_inputs = [runs for _, runs in subjects]
    _log(f"Discovered {len(pids)} cohort subjects (space={training_space}).")

    # Byte-compat reference: first cohort fdt image's affine/header, exactly like
    # atlas/compute.py.
    ref_file: Path = patient_inputs[0][0][0]
    ref_img = nib.load(str(ref_file))
    affine = np.asarray(ref_img.affine)  # type: ignore[attr-defined]
    header = ref_img.header
    # Every NIfTI written below (displacement fields and atlas maps) is float32.
    # Set it before the first save so the fields are not cast to the reference
    # image's on-disk dtype.
    header.set_data_dtype(np.float32)  # type: ignore[attr-defined]
    _rshape = ref_img.shape  # type: ignore[attr-defined]
    ref_shape = (int(_rshape[0]), int(_rshape[1]), int(_rshape[2]))
    _log(f"Reference shape: {ref_shape}; normalization: {norm.value}")

    # --- Training stack ---------------------------------------------------
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        if training_space == "template_native":
            try:
                stack_np = _build_patient_stack(patient_inputs, ref_shape, norm)
            except ValueError as exc:
                raise ProcessingError(
                    f"learned_atlas stack build failed under {cohort_root}: {exc}"
                ) from exc
        elif training_space == "affine_native":
            if not affine_native_relpath:
                raise ProcessingError(
                    "training_space='affine_native' requires affine_native_relpath."
                )
            stack_np = _build_affine_native_stack(  # type: ignore[assignment]
                cohort_root, pids, affine_native_relpath, ref_shape, _log
            )
        else:
            raise ProcessingError(
                f"Unknown training_space {training_space!r}; expected "
                "'template_native' or 'affine_native'."
            )

    n_subjects = int(stack_np.shape[0])
    # The cohort average initialises the learnable template.
    cohort_mean = stack_np.mean(axis=0).astype(np.float32)

    n_templates = int(cfg.get("n_templates", 1))
    if n_templates > 1:
        _log(
            f"n_templates={n_templates}: mixture-of-K is a Phase-2 stub; "
            "training a single template (K=1)."
        )
        n_templates = 1

    # --- Model / loss / optimiser ----------------------------------------
    int_steps = int(cfg.get("int_steps", 7))
    subjects_t = torch.from_numpy(stack_np[:, None]).to(device, torch.float32)
    model = LearnedAtlasModel(
        volume_shape=ref_shape,
        n_templates=n_templates,
        int_steps=int_steps,
        enc_features=tuple(cfg.get("enc_features", (16, 32, 32, 32))),
        dec_features=tuple(cfg.get("dec_features", (32, 32, 32, 16, 16))),
        init_template=torch.from_numpy(cohort_mean),
    ).to(device)
    loss_fn = LearnedAtlasLoss(
        similarity_weight=float(cfg.get("similarity_weight", 1.0)),
        presence_weight=float(cfg.get("presence_weight", 0.5)),
        smoothness_weight=float(cfg.get("smoothness_weight", 0.5)),
        log_offset=float(cfg.get("log_offset", 1e-3)),
        presence_value=float(presence_value),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg.get("lr", 1e-4)))
    epochs = int(cfg.get("epochs", 50))
    batch = max(1, int(cfg.get("batch_size", 2)))
    # Mixed precision only makes sense on CUDA with a reduced dtype.
    use_autocast = device.type == "cuda" and autocast_dtype != torch.float32

    # --- Training loop ----------------------------------------------------
    model.train()
    history: list = []
    try:
        for epoch in range(epochs):
            perm = torch.randperm(n_subjects, device=device)
            epoch_total = 0.0
            for start in range(0, n_subjects, batch):
                idx = perm[start : start + batch]
                target = subjects_t[idx]
                optimizer.zero_grad(set_to_none=True)
                if use_autocast:
                    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                        out = model(target)
                    # Loss (incl. integration-sensitive NCC) computed in float32.
                    loss, comps = loss_fn(
                        out["warped_template"].float(), target, out["velocity"].float()
                    )
                else:
                    out = model(target)
                    loss, comps = loss_fn(out["warped_template"], target, out["velocity"])
                # A NaN/Inf loss would otherwise silently corrupt the template.
                if not bool(torch.isfinite(loss.detach()).all()):
                    raise ProcessingError(
                        f"learned_atlas loss became non-finite at epoch {epoch + 1} "
                        f"(components: {comps})."
                    )
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch mean is per-subject.
                epoch_total += comps["total"] * target.shape[0]
            mean_total = epoch_total / n_subjects
            history.append({"epoch": epoch, "total": mean_total})
            if epoch % max(1, epochs // 10) == 0 or epoch == epochs - 1:
                _log(f"epoch {epoch + 1}/{epochs} total={mean_total:.6e}")
    except ProcessingError:
        raise
    except Exception as exc:  # noqa: BLE001
        raise ProcessingError(f"learned_atlas training loop failed: {exc}") from exc

    # --- Final pass: warped reconstructions, per-subject fields, Jacobian check.
    model.eval()
    warped_stack = np.empty((n_subjects, *ref_shape), dtype=np.float32)
    field_paths: List[str] = []
    min_jacobian = float("inf")
    n_folded = 0
    with torch.no_grad():
        for i, pid in enumerate(pids):
            out = model(subjects_t[i : i + 1])
            # reshape (not squeeze) so a singleton spatial axis is never dropped.
            warped_stack[i] = (
                out["warped_template"].to("cpu", torch.float32).reshape(ref_shape).numpy()
            )
            if verify_jacobian:
                det_min, det_folds = jacobian_min_and_folds(out["deformation"])
                min_jacobian = min(min_jacobian, det_min)
                n_folded += det_folds
            if save_fields:
                # (1, 3, X, Y, Z) -> (X, Y, Z, 3) channels-last for NIfTI.
                field_np = (
                    out["deformation"]
                    .squeeze(0)
                    .permute(1, 2, 3, 0)
                    .to("cpu", torch.float32)
                    .numpy()
                )
                field_file = fields_dir / f"{pid}_voxel_displacement.nii.gz"
                nib.save(nib.Nifti1Image(field_np, affine, header), str(field_file))
                field_paths.append(str(field_file))
    if verify_jacobian:
        _log(
            f"Diffeomorphism check (voxel-space): min Jacobian det={min_jacobian:.4e}, "
            f"folded voxels={n_folded} (expected 0)."
        )

    # --- Emit maps: 'mean' = learned template; dispersion maps from warped stack
    # via the SAME compute_atlas_statistics used by the averaging atlas.
    template_np = model.export_template()[0].numpy().astype(np.float32)  # (X, Y, Z), K=1
    derived = compute_atlas_statistics(
        warped_stack,
        presence_value=float(presence_value),
        cov_mean_threshold_pct=float(cov_mean_threshold_pct),
    )
    derived["mean"] = template_np  # the sharp learned template is the deliverable mean
    # 'mean' is the deliverable, so it is always written regardless of the request.
    wanted_maps = set(emit_baseline_maps) | {"mean"}
    generated: List[str] = []
    for name in ATLAS_STATISTIC_NAMES:
        if name not in wanted_maps:
            continue
        out_file = atlas_path / ATLAS_FILENAME_MAP[name]
        nib.save(nib.Nifti1Image(derived[name].astype(np.float32), affine, header), str(out_file))
        generated.append(str(out_file))
        _log(f"Saved {out_file}")

    meta = {
        "n_subjects": n_subjects,
        "training_space": training_space,
        "normalization_method": norm.value,
        "device": requested_device,
        "dtype": dtype_name,
        "epochs": epochs,
        "batch_size": batch,
        "int_steps": int_steps,
        "min_jacobian_determinant_voxelspace": (
            min_jacobian if min_jacobian != float("inf") else None
        ),
        "n_neg_jac_voxelspace": n_folded,
        "final_total_loss": history[-1]["total"] if history else None,
        "history": history,
        "subject_ids": pids,
        "field_convention": "voxel_displacement (X,Y,Z,3); NOT an ANTs/SITK warp",
    }
    meta_file = atlas_path / "learned_atlas_training.json"
    meta_file.write_text(json.dumps(meta, indent=2), encoding="utf-8")
    generated.extend(field_paths)
    generated.append(str(meta_file))

    # Release VRAM before returning: drop every local that may still reference a
    # device tensor (the last batch / model outputs keep storage alive otherwise).
    out = loss = comps = target = perm = idx = None
    del model, optimizer, subjects_t, loss_fn
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return generated
[docs]
def jacobian_min_and_folds(deformation: "object") -> tuple:
    """Summarise the voxel-space Jacobian determinant of a deformation field.

    Thin wrapper whose import of the model helper happens at call time, so this
    module needs no top-level torch import to be shipped into the Nipype
    subprocess. Operates on a torch tensor.

    Args:
        deformation: Tensor ``(1, 3, X, Y, Z)`` of voxel displacements.

    Returns:
        ``(min_det, n_non_positive)`` as ``(float, int)``: the smallest
        determinant and the number of voxels whose determinant is ``<= 0``
        (folded or degenerate).
    """
    from thesis.workflows.learned_atlas.model import jacobian_determinant as _jacobian

    determinant = _jacobian(deformation)  # type: ignore[arg-type]
    smallest = float(determinant.min().item())
    folded_voxels = int((determinant <= 0).sum().item())
    return smallest, folded_voxels
def _build_affine_native_stack(
    cohort_root: "object",
    pids: list,
    relpath: str,
    ref_shape: tuple,
    log: "object",
) -> "object":
    """Stack affine-only-aligned native maps (avoids the double-registration confound).

    Reads ``<cohort_root>/<pid>/<relpath>`` for every subject, in ``pids`` order.
    ``relpath`` may point at a per-tract/bundle map or at a whole-brain map;
    either way exactly one 3-D volume is taken per subject. For a 4-D file only
    the first volume (``[..., 0]``) is used.

    Args:
        cohort_root: Cohort root directory (anything ``str()``-able to a path).
        pids: Subject identifiers; each is used as a subdirectory name.
        relpath: Map path relative to each subject directory.
        ref_shape: Expected spatial shape ``(X, Y, Z)``.
        log: Callable taking a single message string.

    Returns:
        ``float32`` array of shape ``(len(pids), X, Y, Z)``.

    Raises:
        ProcessingError: If ``pids`` is empty, a map is missing, a map has more
            than four dimensions, or its spatial shape mismatches ``ref_shape``.
    """
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import ProcessingError

    if not pids:
        # np.stack would otherwise fail with an opaque "need at least one array".
        raise ProcessingError("No subjects given for the affine-native stack.")

    expected_shape = tuple(int(s) for s in ref_shape)
    root = Path(str(cohort_root))
    volumes = []
    for pid in pids:
        map_file = root / str(pid) / relpath
        if not map_file.is_file():
            raise ProcessingError(f"Affine-native map not found for {pid}: {map_file}")
        img = nib.load(str(map_file))
        # Shape comes from the header, so mismatches fail before any voxel I/O.
        shape = tuple(img.shape)  # type: ignore[attr-defined]
        if len(shape) > 4:
            raise ProcessingError(
                f"Affine-native map {map_file} has {len(shape)} dimensions; expected 3 or 4."
            )
        if shape[:3] != expected_shape:
            raise ProcessingError(
                f"Affine-native map {map_file} shape {shape[:3]} != {ref_shape}"
            )
        proxy = img.dataobj  # type: ignore[attr-defined]
        # Slicing the proxy reads only the first volume of a 4-D file from disk.
        data = proxy[..., 0] if len(shape) == 4 else proxy
        volumes.append(np.asarray(data, dtype=np.float32))
    log(f"Loaded {len(volumes)} affine-native maps from '{relpath}'.")  # type: ignore[operator]
    # Every volume is already float32; avoid a second full-stack copy.
    return np.stack(volumes, axis=0).astype(np.float32, copy=False)