Learned Atlas Workflow

Cohort-level workflow that trains a learned conditional deformable tract-density template (VoxelMorph/AtlasMorph paradigm) and emits the same five atlas maps as the averaging atlas, byte-compatibly with downstream atlas_to_patient and tract_similarity consumers.

thesis.workflows.learned_atlas

Learned conditional deformable tract-density atlas (cohort workflow).

A VoxelMorph/AtlasMorph-style learned template that replaces the voxel-wise averaging atlas (thesis.workflows.atlas): it trains a single sharp learnable base template T plus a diffeomorphic registration U-Net on the cohort’s tract-density (TDI) volumes, then emits the five atlas maps (atlas_mean.nii.gz et al.) byte-compatibly with the averaging atlas so the existing atlas_to_patient and tract_similarity paths keep working unchanged. The network emits only deformation fields (integrated from a stationary velocity field by scaling and squaring), never density values, so the produced atlas cannot hallucinate structure.

The model runs ONLY at atlas-build time on the cohort; patients consume the produced template via the existing atlas_to_patient path.

The learned_atlas workflow and its learned_atlas config namespace are registered in the workflow submodule by the @workflow decorator; importing this package triggers that registration.
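Register-on-import can be illustrated with a plain module-level registry. This is a sketch of the general pattern only: `WORKFLOWS`, `CONFIG_SCHEMAS`, the positional `name` argument and `LearnedAtlasConfigStub` are hypothetical stand-ins, not the framework's actual `@workflow` implementation. The point is that the decorator runs, and therefore registers, as soon as the module defining the decorated function is imported.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical module-level registries; the framework's real storage is not documented.
WORKFLOWS: Dict[str, Callable[..., Any]] = {}
CONFIG_SCHEMAS: Dict[str, type] = {}


def workflow(name: str, *, config_namespace: Optional[str] = None,
             config_schema: Optional[type] = None
             ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator: registration is a side effect of the decorated definition."""
    def decorator(build_fn: Callable[..., Any]) -> Callable[..., Any]:
        WORKFLOWS[name] = build_fn
        if config_namespace is not None and config_schema is not None:
            CONFIG_SCHEMAS[config_namespace] = config_schema
        return build_fn
    return decorator


class LearnedAtlasConfigStub:
    """Stand-in for the real pydantic LearnedAtlasConfig schema."""


# Executing this definition, i.e. importing the module that contains it,
# is what populates both registries.
@workflow("learned_atlas", config_namespace="learned_atlas",
          config_schema=LearnedAtlasConfigStub)
def build_workflow(**kwargs: Any) -> str:
    return "workflow object"
```

Nothing is registered until the defining module is executed, which is why the CLI must import the selected workflow before it loads and validates config.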

thesis.workflows.learned_atlas.workflow

Nipype workflow definition for the learned conditional deformable atlas.

Owns the cohort-scope @workflow registration, the learned_atlas config namespace binding, the pre-flight verifier, and the wiring of a single Nipype Function node (the trainer). The trainer body lives in the sibling train module; the Function wrapper here imports and calls it so torch stays out of this module’s import path.

Byte-compatibility: the trainer derives the output affine/header from the first cohort fdt image exactly as thesis.workflows.atlas.compute.generate_statistical_atlas() does (NOT from a pre-existing atlas file), so the learned template is byte-compatible with the averaging atlas’s atlas_mean.nii.gz on every clean run.

Note on import ordering: config.learned_atlas resolves only after this package has been imported (the import registers the namespace). The CLI imports the selected workflow before loading config, so thesis run -w learned_atlas works; as a safeguard, verify_requirements and build_workflow read the namespace with getattr rather than plain attribute access.

thesis.workflows.learned_atlas.workflow.verify_requirements(config, context)

Pre-flight checks for the learned_atlas workflow.

Verifies (1) the cohort output directory exists, (2) torch is importable (the ml extra is installed), (3) the cohort has at least learned_atlas.min_subjects subjects with valid tractography input, and (4) when training_space='affine_native', warns that the output is NOT a drop-in for the template-space atlas_to_patient jobs.

Return type:

List[str]

Returns:

A list of human-readable error strings; empty means ready to build.
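The four checks can be sketched as follows, assuming the cohort layout described under train_learned_atlas (numeric patient subdirectories under the cohort root, each holding its tractography output at a relative path). The function name, its flattened arguments and the message wording are hypothetical; the sketch only illustrates the return-a-list-of-errors contract, a torch probe that does not import torch, and the defensive getattr on the config namespace.

```python
import importlib.util
import warnings
from pathlib import Path
from typing import Any, List


def verify_requirements_sketch(config: Any, output_dir: Path,
                               tractography_relpath: str) -> List[str]:
    """Hypothetical pre-flight check: returns error strings; empty means ready."""
    # Defensive lookup: the namespace exists only once the package was imported.
    la = getattr(config, "learned_atlas", None)
    if la is None:
        return ["learned_atlas config namespace is not registered"]
    # (1) The cohort output directory must exist.
    if not output_dir.is_dir():
        return [f"cohort output directory does not exist: {output_dir}"]
    errors: List[str] = []
    # (2) Probe for torch without importing it (find_spec only checks it can be located).
    if importlib.util.find_spec("torch") is None:
        errors.append("torch is not importable; install the ml extra")
    # (3) Count numeric subject dirs that actually contain the tractography input.
    n_valid = sum(
        1 for p in output_dir.iterdir()
        if p.is_dir() and p.name.isdigit() and (p / tractography_relpath).is_file()
    )
    if n_valid < la.min_subjects:
        errors.append(f"found {n_valid} subjects with tractography input, "
                      f"need at least {la.min_subjects}")
    # (4) affine_native is allowed, so it raises a warning rather than an error.
    if la.training_space == "affine_native":
        warnings.warn("affine_native output is NOT a drop-in for template-space "
                      "atlas_to_patient jobs")
    return errors
```

Returning strings instead of raising lets the caller report every problem at once before any Nipype graph is built.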

thesis.workflows.learned_atlas.workflow.build_workflow(*, atlas_dir, config, context)

Build the cohort-level learned-atlas training workflow.

Parameters:
  • atlas_dir (Path) – Resolved CohortDir output (created on resolution); the single source of truth for the output location.

  • config (PipelineConfig) – Fully merged pipeline configuration.

  • context (ProcessingContext) – Processing context carrying the cohort output_dir.

Return type:

Workflow

Returns:

A configured nipype.pipeline.engine.Workflow.

Raises:

ValueError – If context.output_dir is not set.

Learned atlas internals

Config schema and constants for the learned_atlas cohort workflow.

The LearnedAtlasConfig model is registered as the learned_atlas config namespace by the @workflow decorator (see workflow.py). Any learned_atlas: block in a YAML config is validated against this schema at config-load time and reachable from the workflow body via getattr(config, "learned_atlas").

The learned (deformable VoxelMorph/AtlasMorph) template remains drop-in compatible with the averaging atlas: it emits the same thesis.workflows.atlas._params.ATLAS_STATISTIC_NAMES map set so the existing atlas_to_patient and tract_similarity paths keep working.

This module depends only on pydantic + thesis.core.config.validators (and the atlas params for the map-name source of truth); it never imports torch, so importing the package is safe without the optional ml extra.

thesis.workflows.learned_atlas._params.DEFAULT_INT_STEPS: int = 7

Scaling-and-squaring integration steps (2**7 = 128 sub-steps); the VoxelMorph default.

thesis.workflows.learned_atlas._params.DEFAULT_LOG_OFFSET: float = 0.001

Additive offset before log-transforming sparse TDI volumes for the NCC term.

class thesis.workflows.learned_atlas._params.LearnedAtlasModelConfig

Bases: BaseConfig

Network / template architecture knobs for the learned atlas.

Variables:
  • channels – Input channels (1 for a single whole-map / per-bundle TDI volume).

  • n_templates – Number of learnable base templates K. 1 is the single sharp template; >1 selects the Phase-2 mixture-of-K stub.

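  • int_steps – Scaling-and-squaring integration steps for the diffeomorphic velocity field. The integrated map is diffeomorphic in the continuous limit, so det(Jacobian) > 0 is expected; the discretised field can still fold, which is what verify_jacobian checks.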

  • enc_features – Registration U-Net encoder channel widths (coarse-to-fine).

  • dec_features – Registration U-Net decoder channel widths.

Parameters:

data (Any)

channels: int
n_templates: int
int_steps: int
enc_features: List[int]
dec_features: List[int]
model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

class thesis.workflows.learned_atlas._params.LearnedAtlasLossConfig

Bases: BaseConfig

Loss-term weights for learned atlas training.

Variables:
  • similarity_weight – Weight on the log-domain local-NCC reconstruction term.

  • presence_weight – Weight on the presence-weighted soft-Dice term on the true (subject) support; guards against the do-nothing / regression-to-mean failure mode that union-support correlation flatters.

  • smoothness_weight – Weight on the deformation-field gradient regulariser.

  • log_offset – Additive offset before the log transform of sparse TDI volumes.

Parameters:

data (Any)

similarity_weight: float
presence_weight: float
smoothness_weight: float
log_offset: float
model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
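How the three weights might combine is sketched below with numpy, assuming LearnedAtlasLoss is a plain weighted sum and that the smoothness term is the mean squared forward-difference gradient of the velocity field (the usual VoxelMorph-style diffusion regulariser). The similarity and presence terms are passed in as precomputed scalars because their exact forms live in the losses module; `gradient_smoothness` and `composite_loss` are hypothetical names.

```python
import numpy as np


def gradient_smoothness(velocity: np.ndarray) -> float:
    """Diffusion regulariser: mean squared forward difference, averaged over X, Y, Z.

    velocity: (N, 3, X, Y, Z) field in voxel units.
    """
    terms = [np.mean(np.diff(velocity, axis=axis) ** 2) for axis in (2, 3, 4)]
    return float(np.mean(terms))


def composite_loss(ncc_loss: float, dice_loss: float, velocity: np.ndarray,
                   similarity_weight: float, presence_weight: float,
                   smoothness_weight: float) -> float:
    """Weighted sum of the three terms (assumed form of LearnedAtlasLoss)."""
    return (similarity_weight * ncc_loss
            + presence_weight * dice_loss
            + smoothness_weight * gradient_smoothness(velocity))
```

A spatially constant velocity costs nothing under this regulariser, so the penalty only discourages rough, locally varying warps.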

class thesis.workflows.learned_atlas._params.LearnedAtlasOptimizerConfig

Bases: BaseConfig

Optimizer / training-loop knobs for learned atlas training.

Variables:
  • lr – Adam learning rate for the joint template + U-Net optimisation.

  • epochs – Number of training epochs over the cohort.

  • batch_size – Subjects per gradient step (bounded by GPU VRAM for ~145^3-voxel volumes).

  • seed – RNG seed for reproducible template init and shuffling.

Parameters:

data (Any)

lr: float
epochs: int
batch_size: int
seed: int
model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

class thesis.workflows.learned_atlas._params.LearnedAtlasConfig

Bases: BaseConfig

Pydantic schema for the learned_atlas top-level config key.

Registered with the framework by @workflow(config_namespace="learned_atlas", config_schema=LearnedAtlasConfig); any learned_atlas: block in a YAML config is type-checked against this model and reachable via getattr(config, "learned_atlas") from the workflow body.

Variables:
  • enabled – Master switch for learned atlas training.

  • training_space – template_native (already template-registered warped volumes; the only space that is byte-compatible / drop-in for atlas_to_patient) or affine_native (affine-only-aligned native maps; avoids the double-registration confound but is NOT consumable by the existing template-space atlas_to_patient jobs).

  • device – Torch device string (‘cuda’, ‘cuda:0’, ‘cpu’).

  • dtype – Training/autocast dtype name (‘float32’, ‘float16’, ‘bfloat16’).

  • affine_native_relpath – Relative path under each patient directory to the affine-only-aligned native map; required when training_space='affine_native'. May be a per-tract/bundle path and degrades gracefully to the whole-map case.

  • emit_baseline_maps – Averaging-atlas maps to emit. Defaults to the FULL set so the learned dir is a true drop-in for any atlas_to_patient job config; must be a subset of ATLAS_STATISTIC_NAMES and include ‘mean’.

  • output_subdir – Cohort output subdirectory for the learned template and per-subject fields. Single source of truth for the output location; the @produces CohortDir literal must equal this default.

  • min_subjects – Minimum cohort size required to train.

  • save_fields – Write per-subject deformation fields alongside the maps.

  • verify_jacobian – Verify every predicted field has det(Jacobian) > 0.

  • model – Network / template architecture knobs.

  • loss – Loss-term weights.

  • optimizer – Optimizer / training-loop knobs.

Parameters:

data (Any)

enabled: bool
training_space: Literal['template_native', 'affine_native']
device: str
dtype: str
affine_native_relpath: str | None
emit_baseline_maps: List[str]
output_subdir: str
min_subjects: int
save_fields: bool
verify_jacobian: bool
model: LearnedAtlasModelConfig
loss: LearnedAtlasLossConfig
optimizer: LearnedAtlasOptimizerConfig
model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
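For orientation, a learned_atlas block is sketched below as the Python dict it would parse to, together with a plain-Python check of the cross-field rules stated above (emit_baseline_maps must be a subset of ATLAS_STATISTIC_NAMES that includes 'mean'; affine_native_relpath is required for affine_native). Every value in the dict is illustrative, not a documented default; the five map names are assumed from the statistics listed under train_learned_atlas, and `cross_field_errors` is a hypothetical stand-in for the pydantic validation.

```python
from typing import Any, Dict, List

# Assumed contents of ATLAS_STATISTIC_NAMES (the five averaging-atlas maps).
ATLAS_STATISTIC_NAMES = ("mean", "std", "std_error", "cov", "prob_threshold")

# Illustrative values only; omitted keys fall back to the schema defaults.
example_block: Dict[str, Any] = {
    "enabled": True,
    "training_space": "template_native",
    "device": "cuda:0",
    "dtype": "float32",
    "affine_native_relpath": None,
    "emit_baseline_maps": list(ATLAS_STATISTIC_NAMES),
    "min_subjects": 10,
    "save_fields": True,
    "verify_jacobian": True,
    "model": {"n_templates": 1, "int_steps": 7},
    "optimizer": {"lr": 1e-4, "epochs": 100, "batch_size": 1, "seed": 0},
}


def cross_field_errors(block: Dict[str, Any]) -> List[str]:
    """Hypothetical check of the cross-field rules documented for LearnedAtlasConfig."""
    errors: List[str] = []
    maps = block.get("emit_baseline_maps", [])
    unknown = sorted(set(maps) - set(ATLAS_STATISTIC_NAMES))
    if unknown:
        errors.append(f"unknown maps: {unknown}")
    if "mean" not in maps:
        errors.append("emit_baseline_maps must include 'mean'")
    if block.get("training_space") == "affine_native" and not block.get("affine_native_relpath"):
        errors.append("affine_native_relpath is required for affine_native")
    return errors
```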

Diffeomorphic registration U-Net and learnable template for the learned atlas.

Compact in-repo VoxelMorph/AtlasMorph model. The network predicts a stationary velocity field only; scaling-and-squaring integration turns it into a diffeomorphic deformation, and a spatial transformer resamples a jointly-learned base template T. The network emits only deformation fields (never density), so the warped template can relocate existing template mass but cannot hallucinate new mass.

All tensors are 5D: images are (N, 1, X, Y, Z), deformation fields (N, 3, X, Y, Z), and the template (K, 1, X, Y, Z). torch is imported lazily so importing the package is safe without the optional ml extra; these classes are only used inside the Nipype Function body.
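Scaling and squaring can be illustrated in 1D with numpy. This is a sketch, not the module's torch implementation: it assumes voxel-unit displacements and uses np.interp (linear, edge-clamped) as the resampler in place of a spatial transformer. Starting from v / 2**int_steps, each squaring composes the map with itself, u(x) <- u(x) + u(x + u(x)), so int_steps=7 (DEFAULT_INT_STEPS) covers 2**7 = 128 sub-steps.

```python
import numpy as np


def scaling_and_squaring_1d(velocity: np.ndarray, int_steps: int = 7) -> np.ndarray:
    """Integrate a stationary 1D velocity field (voxel units) into a displacement field."""
    grid = np.arange(velocity.shape[0], dtype=np.float64)
    # Scale: a tiny step v / 2**int_steps is well approximated by phi(x) = x + v(x) / 2**n.
    disp = velocity.astype(np.float64) / (2 ** int_steps)
    # Square: phi <- phi o phi, i.e. u(x) <- u(x) + u(x + u(x)), repeated int_steps times.
    for _ in range(int_steps):
        disp = disp + np.interp(grid + disp, grid, disp)
    return disp
```

A constant velocity integrates to a pure translation, and a smooth field stays monotone, which is the 1D analogue of det(Jacobian) > 0.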

thesis.workflows.learned_atlas.model.jacobian_determinant(deformation)

Voxel-space Jacobian determinant of phi(x) = x + deformation(x).

NOTE: voxel-space only (internal fold-free check). The displacement field is in VOXEL units, so this determinant must NOT be fed to antsApplyTransforms or compared against ANTs/SimpleITK physical-space warps: physical-space validation requires first composing the field with the reference affine.

Parameters:

deformation (torch.Tensor) – Displacement field (N, 3, X, Y, Z) in voxel units.

Return type:

torch.Tensor

Returns:

Determinant (N, X-1, Y-1, Z-1).

Raises:

DependencyError – When torch is unavailable.

thesis.workflows.learned_atlas.model.count_negative_jacobians(deformation)

Return the number of voxels with a non-positive (folding) Jacobian determinant (voxel space).

Parameters:

deformation (torch.Tensor)

Return type:

int
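A numpy sketch of the voxel-space check follows. It assumes forward differences, which matches the documented (N, X-1, Y-1, Z-1) output shape, although the module's actual stencil is not stated. `jacobian_determinant_np` and `min_and_folds` are hypothetical names; the second mirrors the (min_det, n_negative) contract of jacobian_min_and_folds.

```python
from typing import Tuple

import numpy as np


def jacobian_determinant_np(deformation: np.ndarray) -> np.ndarray:
    """det(I + grad u) for phi(x) = x + u(x); u is (N, 3, X, Y, Z) in voxel units."""
    u = deformation
    base = u[:, :, :-1, :-1, :-1]
    # Forward differences along X, Y, Z, cropped to a common (X-1, Y-1, Z-1) grid.
    d_dx = u[:, :, 1:, :-1, :-1] - base
    d_dy = u[:, :, :-1, 1:, :-1] - base
    d_dz = u[:, :, :-1, :-1, 1:] - base
    # grad[n, i, j, ...] = d u_i / d x_j
    grad = np.stack([d_dx, d_dy, d_dz], axis=2)
    # Move the 3x3 matrix axes last: (N, X-1, Y-1, Z-1, 3, 3).
    grad = np.moveaxis(grad, (1, 2), (-2, -1))
    return np.linalg.det(grad + np.eye(3))


def min_and_folds(deformation: np.ndarray) -> Tuple[float, int]:
    """(minimum determinant, number of non-positive voxels)."""
    det = jacobian_determinant_np(deformation)
    return float(det.min()), int(np.count_nonzero(det <= 0))
```

As noted above, these are voxel-space numbers; they say nothing about physical-space warps until the field is composed with the reference affine.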

TDI-appropriate losses for learned-atlas training.

Tract-density (TDI / fdt_paths) volumes are heavy-tailed and sparse, so this module provides:

  • LocalNCCLogLoss - windowed local NCC on log(x + offset) density, accumulated in float32 with clamped variances for numerical stability.

  • PresenceWeightedSoftDice - soft Dice weighted toward the subject’s TRUE support (not a union support), guarding against do-nothing models.

  • GradientSmoothness - diffusion regulariser on the velocity field.

  • LearnedAtlasLoss - configurable weighted composite. (No unbounded template-sharpness term: sharpness is left to the data term + the warp.)

torch is imported lazily; these losses are only used inside the Function body.
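The log-domain local NCC idea is sketched below in numpy, under several assumptions that may differ from LocalNCCLogLoss: a 3-voxel cubic window, valid-mode windows via numpy.lib.stride_tricks.sliding_window_view (numpy 1.20 or later) instead of a padded convolution, VoxelMorph's squared-correlation form, and a variance floor `eps` of 1e-5. Only the default log_offset of 0.001 comes from DEFAULT_LOG_OFFSET; `local_ncc_log` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def local_ncc_log(pred: np.ndarray, target: np.ndarray, win: int = 3,
                  log_offset: float = 1e-3, eps: float = 1e-5) -> float:
    """1 - mean squared local correlation of log(x + offset) volumes (valid windows only)."""
    # The log compresses the heavy tail of TDI counts; computed in float32.
    a = np.log(pred.astype(np.float32) + log_offset)
    b = np.log(target.astype(np.float32) + log_offset)

    def local_mean(x: np.ndarray) -> np.ndarray:
        # Mean over every win**ndim window; output shrinks by win - 1 per axis.
        windows = sliding_window_view(x, (win,) * x.ndim)
        return windows.mean(axis=tuple(range(-x.ndim, 0)))

    mu_a, mu_b = local_mean(a), local_mean(b)
    # Clamp variances so flat (e.g. empty background) windows cannot divide by zero.
    var_a = np.maximum(local_mean(a * a) - mu_a * mu_a, eps)
    var_b = np.maximum(local_mean(b * b) - mu_b * mu_b, eps)
    cov = local_mean(a * b) - mu_a * mu_b
    cc = cov * cov / (var_a * var_b)
    return float(1.0 - cc.mean())
```

Identical volumes give a loss near 0, while spatially scrambled ones give a loss near 1, because local correlation only rewards structure that lines up voxel by voxel.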

Learned conditional deformable atlas - training entry + atlas emission.

train_learned_atlas() is called by the learned_atlas Nipype Function node. It loads the cohort tract-density stack (reusing the averaging-atlas IO and discovery so subject IDs and stack order stay aligned), trains a learnable template T plus a deformation-only diffeomorphic network, then emits the five atlas maps byte-compatibly with the averaging atlas.

Byte-compatibility: the output affine/header are taken from the first cohort fdt image exactly as thesis.workflows.atlas.compute.generate_statistical_atlas() does (NOT from a pre-existing atlas file), and mean is the learned template while std/std_error/cov/prob_threshold come from thesis.workflows.atlas._statistics.compute_atlas_statistics() over the warped-template population - identical keys, dtypes (float32) and formulae.

Gotchas: all heavy imports are LOCAL (the body is pickled into a Nipype subprocess and loguru cannot be pickled); diagnostics use print(..., file=sys.stderr); torch is the optional ml extra; VRAM is released before returning.

thesis.workflows.learned_atlas.train.train_learned_atlas(input_dir, atlas_dir, tractography_relpath, normalization_method, training_space, affine_native_relpath, emit_baseline_maps, presence_value, cov_mean_threshold_pct, save_fields, verify_jacobian, train_config)

Train the learned deformable atlas and emit atlas NIfTIs.

Parameters:
  • input_dir (str) – Cohort root containing numeric patient subdirectories.

  • atlas_dir (str) – Directory where atlas NIfTI files and fields/ are saved.

  • tractography_relpath (str) – Relative path under each subject to the template-space tractography output.

  • normalization_method (str) – NormalizationMethod value string.

  • training_space (str) – "template_native" or "affine_native".

  • affine_native_relpath (Optional[str]) – Per-subject affine-native map path; required when training_space='affine_native'.

  • emit_baseline_maps (list) – Subset of ATLAS_STATISTIC_NAMES to write (always includes ‘mean’).

  • presence_value (float) – Threshold for the derived prob_threshold map.

  • cov_mean_threshold_pct (float) – Fraction of the mean max for the cov map.

  • save_fields (bool) – Whether to write per-subject voxel-displacement fields.

  • verify_jacobian (bool) – Whether to count non-positive Jacobian voxels per subject.

  • train_config (dict) – Device/dtype/model/loss/optimizer scalars.

Return type:

List[str]

Returns:

Absolute paths to generated files (atlas maps in ATLAS_STATISTIC_NAMES order, then per-subject fields, then the metadata JSON).

thesis.workflows.learned_atlas.train.jacobian_min_and_folds(deformation)

Return (min determinant, fold count) for a voxel-space deformation field.

Local-import wrapper around the model’s voxel-space Jacobian so it pickles into the Nipype subprocess. Operates on a torch tensor.

Parameters:

deformation (object) – Tensor (1, 3, X, Y, Z) of voxel displacements.

Return type:

tuple

Returns:

(min_det, n_negative) as (float, int).

Phase-0 diagnostics and non-circular evaluation metrics for the learned atlas.

Framework-free: every function operates on in-memory numpy arrays (typically the (n_subjects, X, Y, Z) stack from atlas/_io._build_patient_stack) and returns plain floats / arrays / JSON-serialisable dicts. No torch, Nipype, or workflow-orchestration dependencies, mirroring thesis.workflows.atlas.qc_metrics and thesis.workflows.tract_similarity._metrics.

Two families:

  • Phase-0 diagnostics answer “is a learned deformable atlas worth training on this cohort?” before any GPU time is spent.

  • Non-circular evaluation metrics score warp(T, field_i) against the held-out subject on its TRUE support (not the union support that flatters do-nothing models), plus a regression-to-mean control.

thesis.workflows.learned_atlas.diagnostics.residual_variance_map(value_maps)

Per-voxel residual variance of the cohort against its mean.

Parameters:

value_maps (ndarray)

Return type:

ndarray

thesis.workflows.learned_atlas.diagnostics.centroid_scatter(value_maps, voxel_size_mm=(1.0, 1.0, 1.0), eps=1e-12)

Quantify how far per-subject density centroids scatter about their mean.

Large scatter is direct evidence of spatial (not intensity) disagreement - what a learned deformation field is meant to remove.

Return type:

dict[str, float]

Returns:

Dict with mean_distance_mm/rms_distance_mm/max_distance_mm and n_subjects used (empty-mass subjects excluded).
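A numpy sketch of the centroid-scatter idea, assuming "centroid" means the density-weighted centre of mass in millimetres (voxel index times voxel_size_mm, with no affine origin applied) and that distances are measured to the mean of the per-subject centroids. The NaN result for a cohort with no non-empty subjects is an assumption of this sketch; `centroid_scatter_np` is a hypothetical name.

```python
from typing import Dict, Sequence

import numpy as np


def centroid_scatter_np(value_maps: np.ndarray,
                        voxel_size_mm: Sequence[float] = (1.0, 1.0, 1.0),
                        eps: float = 1e-12) -> Dict[str, float]:
    """Density-weighted centroid scatter for an (n_subjects, X, Y, Z) stack."""
    grids = np.meshgrid(*(np.arange(s) for s in value_maps.shape[1:]), indexing="ij")
    coords_mm = np.stack([g * v for g, v in zip(grids, voxel_size_mm)], axis=-1)  # (X, Y, Z, 3)
    centroids = []
    for vol in value_maps:
        mass = float(vol.sum())
        if mass <= eps:  # empty-mass subjects are excluded
            continue
        centroids.append((vol[..., None] * coords_mm).sum(axis=(0, 1, 2)) / mass)
    if not centroids:  # assumed behaviour for an all-empty cohort
        nan = float("nan")
        return {"mean_distance_mm": nan, "rms_distance_mm": nan,
                "max_distance_mm": nan, "n_subjects": 0.0}
    c = np.asarray(centroids)
    dist = np.linalg.norm(c - c.mean(axis=0), axis=1)
    return {"mean_distance_mm": float(dist.mean()),
            "rms_distance_mm": float(np.sqrt(np.mean(dist ** 2))),
            "max_distance_mm": float(dist.max()),
            "n_subjects": float(len(c))}
```

Because each subject is reduced to its centre of mass, a global intensity difference leaves the scatter unchanged; only a spatial shift moves it.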

thesis.workflows.learned_atlas.diagnostics.occupancy_entropy(subject_masks, threshold=0.0, restrict_to_support=True, eps=1e-12)

Mean per-voxel binary occupancy entropy across a cohort.

Accepts a DENSITY stack and binarises internally at threshold (voxels strictly above it count as occupied), so the >0 binarisation is explicit rather than the silent bool-coercion of build_occupancy_map.

Entropy ~0 means voxels are consistently on/off (a sharp template is recoverable); ~1 means subjects flip a coin per voxel.

Return type:

dict[str, float]

Returns:

Dict with mean_entropy_bits, max_entropy_bits, support_voxels.
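The occupancy entropy can be sketched with numpy as the binary entropy of each voxel's occupancy fraction p (the share of subjects whose density exceeds threshold). The definition of the support as "occupied in at least one subject" and the clipping of p by eps to avoid log2(0) are assumptions; `occupancy_entropy_np` is a hypothetical name.

```python
from typing import Dict

import numpy as np


def occupancy_entropy_np(subject_masks: np.ndarray, threshold: float = 0.0,
                         restrict_to_support: bool = True,
                         eps: float = 1e-12) -> Dict[str, float]:
    """Mean / max per-voxel binary occupancy entropy (bits) across a density stack."""
    occupied = subject_masks > threshold  # explicit binarisation, strictly above threshold
    p = occupied.mean(axis=0)  # per-voxel occupancy fraction across subjects
    p_c = np.clip(p, eps, 1.0 - eps)  # keep log2 finite at p = 0 or 1
    h = -(p_c * np.log2(p_c) + (1.0 - p_c) * np.log2(1.0 - p_c))
    if restrict_to_support:
        h = h[p > 0]  # assumed support: occupied in at least one subject
    n = int(h.size)
    if n == 0:
        return {"mean_entropy_bits": float("nan"), "max_entropy_bits": float("nan"),
                "support_voxels": 0.0}
    return {"mean_entropy_bits": float(h.mean()), "max_entropy_bits": float(h.max()),
            "support_voxels": float(n)}
```

A voxel occupied by every subject contributes about 0 bits and a voxel occupied by exactly half the cohort contributes 1 bit, matching the "sharp template" versus "coin flip" reading above.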

thesis.workflows.learned_atlas.diagnostics.mass_matched_residual_fraction(value_maps, eps=1e-12)

Cheap PROXY for the spatially-driven fraction of cohort residual.

Rescales every subject to the cohort-mean total mass and re-measures residual variance; the surviving fraction is a heuristic for disagreement a deformation could move. NOTE: no alignment is performed, so this is a proxy, not a true lower bound - it is advisory in the go/no-go decision.

Return type:

float

Returns:

Float in [0, 1]; 0.0 when the cohort residual is negligible.
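One plausible reading of the mass-matched proxy is sketched below with numpy. It assumes that "residual" means the per-voxel variance about the cohort mean summed over voxels, that the fraction is matched residual divided by raw residual, clipped to [0, 1], and that empty subjects stay empty after rescaling. The library's exact formula is not given here, and `mass_matched_residual_fraction_np` is a hypothetical name.

```python
import numpy as np


def mass_matched_residual_fraction_np(value_maps: np.ndarray, eps: float = 1e-12) -> float:
    stack = value_maps.astype(np.float64)
    # Total residual: per-voxel variance about the cohort mean, summed over voxels.
    raw = float(stack.var(axis=0).sum())
    if raw <= eps:
        return 0.0
    masses = stack.reshape(len(stack), -1).sum(axis=1)
    target = masses.mean()
    # Rescale each subject to the cohort-mean total mass (empty subjects stay empty).
    scale = np.where(masses > eps, target / np.maximum(masses, eps), 0.0)
    matched = stack * scale.reshape(-1, *([1] * (stack.ndim - 1)))
    # Residual that survives mass matching is the (unaligned) spatial proxy.
    return float(np.clip(matched.var(axis=0).sum() / raw, 0.0, 1.0))
```

Subjects that differ only by a global gain collapse to a fraction near 0, while equal-mass subjects at different positions keep a fraction of 1. No alignment is attempted, so this remains a proxy rather than a bound.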

thesis.workflows.learned_atlas.diagnostics.phase0_go_no_go(value_maps, voxel_size_mm=(1.0, 1.0, 1.0), subject_mask_threshold=0.0, core_occupancy_threshold=0.5, min_centroid_scatter_mm=1.0, max_occupancy_entropy_bits=0.9)

Roll the Phase-0 diagnostics into a single go/no-go decision.

The hard gate is: spatial residual (centroid scatter above min_centroid_scatter_mm) AND a coherent bundle (mean support entropy below max_occupancy_entropy_bits). The mass-matched residual fraction is reported as an ADVISORY proxy (not part of the conjunction) until validated.

Return type:

dict[str, object]

Returns:

JSON-serialisable dict with go (bool), reasons, sub-criteria, raw diagnostics, and core_voxels.
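The gate logic reduces to a conjunction over precomputed diagnostics, sketched below. The default thresholds come from the phase0_go_no_go signature. Which scatter statistic (mean, RMS or max) the real function compares is not stated, so the sketch takes a single scatter value; the function name and output keys are hypothetical, and core_voxels is omitted.

```python
from typing import Dict, List


def go_no_go_sketch(centroid_scatter_mm: float, mean_support_entropy_bits: float,
                    mass_matched_fraction: float,
                    min_centroid_scatter_mm: float = 1.0,
                    max_occupancy_entropy_bits: float = 0.9) -> Dict[str, object]:
    """Hard gate = spatial residual AND coherent bundle; mass fraction is advisory."""
    spatial = centroid_scatter_mm > min_centroid_scatter_mm
    coherent = mean_support_entropy_bits < max_occupancy_entropy_bits
    reasons: List[str] = []
    if not spatial:
        reasons.append(f"centroid scatter {centroid_scatter_mm:.2f} mm <= "
                       f"{min_centroid_scatter_mm} mm: little spatial disagreement to remove")
    if not coherent:
        reasons.append(f"support entropy {mean_support_entropy_bits:.2f} bits >= "
                       f"{max_occupancy_entropy_bits} bits: no coherent bundle to sharpen")
    return {
        "go": spatial and coherent,
        "reasons": reasons,
        "criteria": {"spatial_residual": spatial, "coherent_bundle": coherent},
        # Reported only; deliberately not part of the conjunction.
        "advisory_mass_matched_fraction": mass_matched_fraction,
    }
```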

thesis.workflows.learned_atlas.diagnostics.true_support_mask(truth, threshold=0.0)

Return the support of the ground-truth subject volume (voxels > threshold).

Return type:

ndarray

thesis.workflows.learned_atlas.diagnostics.delta_on_true_support(prediction, truth, loo_mean, support_threshold=0.0, eps=1e-12)

Improvement of the prediction over the LOO mean on the true support.

Headline non-circular metric. It asks: restricted to voxels where the held-out subject has real signal, does warp(T, field_i) reduce error relative to the leave-one-out (LOO) averaging baseline?

Return type:

dict[str, float]

Returns:

Dict with mae_pred, mae_loo_mean, delta_mae (positive = improvement), relative_reduction (delta / mae_loo_mean; range (-inf, 1]), support_voxels. NaN when the support is empty.
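The metric can be sketched with numpy, assuming mean absolute error on the true support (truth > support_threshold) and an eps guard on the denominator of relative_reduction. `delta_on_true_support_np` is a hypothetical name, not the module's code.

```python
from typing import Dict

import numpy as np


def delta_on_true_support_np(prediction: np.ndarray, truth: np.ndarray,
                             loo_mean: np.ndarray, support_threshold: float = 0.0,
                             eps: float = 1e-12) -> Dict[str, float]:
    support = truth > support_threshold  # true support of the held-out subject only
    n = int(np.count_nonzero(support))
    if n == 0:
        nan = float("nan")
        return {"mae_pred": nan, "mae_loo_mean": nan, "delta_mae": nan,
                "relative_reduction": nan, "support_voxels": 0.0}
    mae_pred = float(np.mean(np.abs(prediction[support] - truth[support])))
    mae_loo = float(np.mean(np.abs(loo_mean[support] - truth[support])))
    delta = mae_loo - mae_pred  # positive = prediction beats the LOO mean
    return {"mae_pred": mae_pred, "mae_loo_mean": mae_loo, "delta_mae": delta,
            "relative_reduction": delta / max(mae_loo, eps),
            "support_voxels": float(n)}
```

Since relative_reduction equals 1 - mae_pred / mae_loo_mean, a perfect prediction scores exactly 1, which is the upper end of the documented (-inf, 1] range.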

thesis.workflows.learned_atlas.diagnostics.regression_to_mean_control(prediction, truth, loo_mean, support_threshold=0.0, eps=1e-12)

Detect a model that merely reproduces the cohort mean.

Correlates predicted departure (prediction - loo_mean) against real departure (truth - loo_mean) on the true support. A do-nothing model (prediction ~= loo_mean) yields near-zero correlation / energy ratio.

Return type:

dict[str, float]

Returns:

Dict with departure_pearson (in [-1, 1]; NaN if constant), predicted_departure_energy, real_departure_energy, departure_energy_ratio (near 0 flags a do-nothing model).
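A numpy sketch of the control, assuming "energy" is the sum of squared departures on the true support and that the Pearson correlation is reported as NaN whenever either departure vector is constant. `regression_to_mean_control_np` is a hypothetical name.

```python
from typing import Dict

import numpy as np


def regression_to_mean_control_np(prediction: np.ndarray, truth: np.ndarray,
                                  loo_mean: np.ndarray, support_threshold: float = 0.0,
                                  eps: float = 1e-12) -> Dict[str, float]:
    support = truth > support_threshold
    pred_dep = np.asarray(prediction - loo_mean, dtype=np.float64)[support]
    real_dep = np.asarray(truth - loo_mean, dtype=np.float64)[support]
    # "Energy" taken here as the sum of squared departures (an assumption).
    pred_energy = float(np.sum(pred_dep ** 2))
    real_energy = float(np.sum(real_dep ** 2))
    if pred_dep.size < 2 or pred_dep.std() <= eps or real_dep.std() <= eps:
        pearson = float("nan")  # correlation undefined for constant departures
    else:
        pearson = float(np.corrcoef(pred_dep, real_dep)[0, 1])
    return {"departure_pearson": pearson,
            "predicted_departure_energy": pred_energy,
            "real_departure_energy": real_energy,
            "departure_energy_ratio": pred_energy / max(real_energy, eps)}
```

A model that just returns the LOO mean produces zero predicted departure, hence an energy ratio of 0 and an undefined correlation. A perfect model scores a ratio of 1 and a correlation of 1.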

thesis.workflows.learned_atlas.diagnostics.evaluate_prediction(prediction, truth, loo_mean, support_threshold=0.0)

Bundle the non-circular metrics for one held-out subject.

Return type:

dict[str, dict[str, float]]

Returns:

Dict with delta_on_true_support and regression_to_mean sub-dicts.
