Learned Atlas Workflow#
Cohort-level workflow that trains a learned conditional deformable
tract-density template (VoxelMorph/AtlasMorph paradigm) and emits the same five
atlas maps as the averaging atlas, byte-compatibly with downstream
atlas_to_patient and tract_similarity consumers.
thesis.workflows.learned_atlas#
Learned conditional deformable tract-density atlas (cohort workflow).
A VoxelMorph/AtlasMorph-style learned template that replaces the voxel-wise
averaging atlas (thesis.workflows.atlas): it trains a single sharp
learnable base template T plus a diffeomorphic registration U-Net on the
cohort’s tract-density (TDI) volumes, then emits the five atlas maps
(atlas_mean.nii.gz et al.) byte-compatibly with the averaging atlas so the
existing atlas_to_patient and tract_similarity paths keep working
unchanged. The network emits only deformation fields (scaling-and-squaring
integration), never density values, so the produced atlas cannot hallucinate
structure.
The model runs ONLY at atlas-build time on the cohort; patients consume the
produced template via the existing atlas_to_patient path.
Registration of the learned_atlas workflow and its learned_atlas config
namespace happens inside .workflow via the @workflow decorator; importing
this package triggers that registration.
thesis.workflows.learned_atlas.workflow#
Nipype workflow definition for the learned conditional deformable atlas.
Owns the cohort-scope @workflow registration, the learned_atlas config
namespace binding, the pre-flight verifier, and the wiring of a single Nipype
Function node (the trainer). The trainer body lives in the sibling
train module; the Function wrapper here imports and calls it so torch stays
out of this module’s import path.
Byte-compatibility: the trainer derives the output affine/header from the first
cohort fdt image exactly as thesis.workflows.atlas.compute() does (NOT
from a pre-existing atlas file), so the learned template is byte-compatible with
the averaging atlas’s atlas_mean.nii.gz on every clean run.
Note on import ordering: config.learned_atlas resolves only after this
package is imported (which registers the namespace). The CLI imports the
selected workflow before loading config, so thesis run -w learned_atlas
works; verify_requirements/build_workflow defensively getattr the namespace.
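The defensive lookup mentioned in the import-ordering note can be sketched as follows. This is a minimal illustration, not the workflow's actual code: `get_learned_atlas_config` and the `SimpleNamespace` stand-ins are hypothetical, and only the `getattr(config, "learned_atlas")` access pattern comes from the text above.

```python
from types import SimpleNamespace


def get_learned_atlas_config(config, default=None):
    """Return config.learned_atlas, or `default` when the namespace is not registered."""
    return getattr(config, "learned_atlas", default)


# Stand-ins for a merged pipeline config (hypothetical objects, for illustration only).
registered = SimpleNamespace(learned_atlas=SimpleNamespace(min_subjects=5))
unregistered = SimpleNamespace()  # workflow package not imported before config load

section = get_learned_atlas_config(unregistered)
if section is None:
    print("learned_atlas namespace missing: import thesis.workflows.learned_atlas first")
```

Returning a sentinel instead of raising AttributeError lets the caller turn the missing namespace into a readable pre-flight message.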
- thesis.workflows.learned_atlas.workflow.verify_requirements(config, context)[source]#
Pre-flight checks for the learned_atlas workflow.
Verifies that (1) the cohort output directory exists, (2) torch is importable (the ml extra is installed), and (3) the cohort has at least learned_atlas.min_subjects subjects with valid tractography input. In addition, (4) when training_space='affine_native', it warns that the output is NOT a drop-in for the template-space atlas_to_patient jobs.
- Parameters:
config (PipelineConfig) – Fully merged pipeline configuration.
context (ProcessingContext) – Processing context (carries output_dir for the cohort).
- Return type:
- Returns:
A list of human-readable error strings; empty means ready to build.
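The four checks can be sketched as a plain function that accumulates error strings. This is an illustration under stated assumptions, not the real verify_requirements: `preflight_errors` and all of its parameters are hypothetical, and the subject count is passed in rather than discovered from the cohort.

```python
import importlib.util
import sys
from pathlib import Path


def preflight_errors(output_dir, n_valid_subjects, min_subjects,
                     training_space="template_native", required_module="torch"):
    """Collect human-readable errors; an empty list means ready to build."""
    errors = []
    # (1) the cohort output directory must exist
    if output_dir is None or not Path(output_dir).is_dir():
        errors.append(f"cohort output directory does not exist: {output_dir}")
    # (2) the optional dependency must be importable (checked without importing it)
    if importlib.util.find_spec(required_module) is None:
        errors.append(f"{required_module} is not importable; install the ml extra")
    # (3) the cohort must be large enough to train on
    if n_valid_subjects < min_subjects:
        errors.append(
            f"need at least {min_subjects} subjects with valid tractography "
            f"input, found {n_valid_subjects}"
        )
    # (4) a warning only: the run may proceed, but the output is not a drop-in
    if training_space == "affine_native":
        print("warning: affine_native output is not a drop-in for "
              "template-space atlas_to_patient jobs", file=sys.stderr)
    return errors
```

Using `importlib.util.find_spec` keeps the check cheap: it reports whether the module can be found without paying the cost of importing it.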
- thesis.workflows.learned_atlas.workflow.build_workflow(*, atlas_dir, config, context)[source]#
Build the cohort-level learned-atlas training workflow.
- Parameters:
atlas_dir (Path) – Resolved CohortDir output (created on resolution); the single source of truth for the output location.
config (PipelineConfig) – Fully merged pipeline configuration.
context (ProcessingContext) – Processing context carrying the cohort output_dir.
- Return type:
Workflow
- Returns:
A configured nipype.pipeline.engine.Workflow.
- Raises:
ValueError – If context.output_dir is not set.
Learned atlas internals#
Config schema and constants for the learned_atlas cohort workflow.
The LearnedAtlasConfig model is registered as the learned_atlas
config namespace by the @workflow decorator (see workflow.py). Any
learned_atlas: block in a YAML config is validated against this schema at
config-load time and reachable from the workflow body via
getattr(config, "learned_atlas").
The learned (deformable VoxelMorph/AtlasMorph) template remains drop-in
compatible with the averaging atlas: it emits the same
thesis.workflows.atlas._params.ATLAS_STATISTIC_NAMES map set so the
existing atlas_to_patient and tract_similarity paths keep working.
This module depends only on pydantic + thesis.core.config.validators (and
the atlas params for the map-name source of truth); it never imports torch, so
importing the package is safe without the optional ml extra.
- thesis.workflows.learned_atlas._params.DEFAULT_INT_STEPS: int = 7#
Scaling-and-squaring integration steps (2**7 sub-steps); the VoxelMorph default.
- thesis.workflows.learned_atlas._params.DEFAULT_LOG_OFFSET: float = 0.001#
Additive offset before log-transforming sparse TDI volumes for the NCC term.
- class thesis.workflows.learned_atlas._params.LearnedAtlasModelConfig[source]#
Bases: BaseConfig
Network / template architecture knobs for the learned atlas.
- Variables:
channels – Input channels (1 for a single whole-map / per-bundle TDI volume).
n_templates – Number of learnable base templates K. 1 is the single sharp template; >1 selects the Phase-2 mixture-of-K stub.
int_steps – Scaling-and-squaring integration steps for the diffeomorphic velocity field (guarantees det(Jacobian) > 0).
enc_features – Registration U-Net encoder channel widths (coarse-to-fine).
dec_features – Registration U-Net decoder channel widths.
- Parameters:
data (Any)
- class thesis.workflows.learned_atlas._params.LearnedAtlasLossConfig[source]#
Bases: BaseConfig
Loss-term weights for learned atlas training.
- Variables:
similarity_weight – Weight on the log-domain local-NCC reconstruction term.
presence_weight – Weight on the presence-weighted soft-Dice term on the true (subject) support; guards against the do-nothing / regression-to-mean failure mode that union-support correlation flatters.
smoothness_weight – Weight on the deformation-field gradient regulariser.
log_offset – Additive offset before the log transform of sparse TDI volumes.
- Parameters:
data (Any)
- class thesis.workflows.learned_atlas._params.LearnedAtlasOptimizerConfig[source]#
Bases: BaseConfig
Optimizer / training-loop knobs for learned atlas training.
- Variables:
lr – Adam learning rate for the joint template + U-Net optimisation.
epochs – Number of training epochs over the cohort.
batch_size – Subjects per gradient step (bounded by GPU VRAM for ~145^3 volumes).
seed – RNG seed for reproducible template init and shuffling.
- Parameters:
data (Any)
- class thesis.workflows.learned_atlas._params.LearnedAtlasConfig[source]#
Bases: BaseConfig
Pydantic schema for the learned_atlas top-level config key.
Registered with the framework by @workflow(config_namespace="learned_atlas", config_schema=LearnedAtlasConfig); any learned_atlas: block in a YAML config is type-checked against this model and reachable via getattr(config, "learned_atlas") from the workflow body.
- Variables:
enabled – Master switch for learned atlas training.
training_space – template_native (already template-registered warped volumes; the only space that is byte-compatible / drop-in for atlas_to_patient) or affine_native (affine-only-aligned native maps; avoids the double-registration confound but is NOT consumable by the existing template-space atlas_to_patient jobs).
device – Torch device string ('cuda', 'cuda:0', 'cpu').
dtype – Training/autocast dtype name ('float32', 'float16', 'bfloat16').
affine_native_relpath – Relative path under each patient directory to the affine-only-aligned native map; required when training_space='affine_native'. May be a per-tract/bundle path and degrades gracefully to the whole-map case.
emit_baseline_maps – Averaging-atlas maps to emit. Defaults to the FULL set so the learned dir is a true drop-in for any atlas_to_patient job config; must be a subset of ATLAS_STATISTIC_NAMES and include 'mean'.
output_subdir – Cohort output subdirectory for the learned template and per-subject fields. Single source of truth for the output location; the @produces CohortDir literal must equal this default.
min_subjects – Minimum cohort size required to train.
save_fields – Write per-subject deformation fields alongside the maps.
verify_jacobian – Verify every predicted field has det(Jacobian) > 0.
model – Network / template architecture knobs.
loss – Loss-term weights.
optimizer – Optimizer / training-loop knobs.
- Parameters:
data (Any)
- model: LearnedAtlasModelConfig#
- loss: LearnedAtlasLossConfig#
- optimizer: LearnedAtlasOptimizerConfig#
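The emit_baseline_maps constraint (a subset of ATLAS_STATISTIC_NAMES that includes 'mean') can be sketched without pydantic. The tuple below lists the five map names mentioned in this documentation; its order and the helper `validate_emit_baseline_maps` are assumptions made for illustration, not the schema's real validator.

```python
# Map names as listed in this documentation; the ORDER is an assumption.
ATLAS_STATISTIC_NAMES = ("mean", "std", "std_error", "cov", "prob_threshold")


def validate_emit_baseline_maps(names):
    """Return the names as a list if they form a valid selection, else raise ValueError."""
    names = list(names)
    unknown = [n for n in names if n not in ATLAS_STATISTIC_NAMES]
    if unknown:
        raise ValueError(f"unknown atlas maps: {unknown}")
    if "mean" not in names:
        raise ValueError("emit_baseline_maps must include 'mean'")
    return names


# The documented default is the full set, which keeps the output a drop-in.
selected = validate_emit_baseline_maps(ATLAS_STATISTIC_NAMES)
```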
Diffeomorphic registration U-Net and learnable template for the learned atlas.
Compact in-repo VoxelMorph/AtlasMorph model. The network predicts a stationary
velocity field only; scaling-and-squaring integration turns it into a
diffeomorphic deformation, and a spatial transformer resamples a jointly-learned
base template T. The network emits only deformation fields (never density),
so the warped template can relocate existing template mass but cannot
hallucinate new mass.
All tensors are 5D (N, 1, X, Y, Z); the template is (K, 1, X, Y, Z).
torch is imported lazily so importing the package is safe without the optional
ml extra; these classes are only used inside the Nipype Function body.
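Scaling-and-squaring integration is easiest to see in one dimension. The sketch below uses NumPy rather than torch and linear interpolation via `np.interp`; `integrate_velocity_1d` is a hypothetical helper, not part of the model module, and the real implementation operates on 3D fields with a spatial transformer.

```python
import numpy as np


def integrate_velocity_1d(velocity, int_steps=7):
    """Integrate a stationary 1D velocity field (voxel units) by scaling and squaring.

    Returns the displacement u such that phi(x) = x + u(x).
    """
    velocity = np.asarray(velocity, dtype=np.float64)
    grid = np.arange(velocity.size, dtype=np.float64)
    # Scale: start from a very small displacement v / 2**int_steps.
    disp = velocity / (2 ** int_steps)
    for _ in range(int_steps):
        # Square: compose phi with itself, u(x) <- u(x) + u(x + u(x)).
        disp = disp + np.interp(grid + disp, grid, disp)
    return disp


x = np.arange(64, dtype=np.float64)
velocity = 3.0 * np.sin(2.0 * np.pi * x / 64.0)
phi = x + integrate_velocity_1d(velocity)
# A fold-free 1D map is strictly increasing.
print("fold-free:", bool(np.all(np.diff(phi) > 0)))
```

Each squaring step doubles the integration time, so 7 steps compose the initial small map 2**7 times. Because every intermediate map is a composition of increasing maps, the result stays invertible in this 1D setting.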
- thesis.workflows.learned_atlas.model.jacobian_determinant(deformation)[source]#
Voxel-space Jacobian determinant of phi(x) = x + deformation(x).
NOTE: voxel-space only (internal fold-free check). The displacement field is in VOXEL units, so this determinant must NOT be fed to antsApplyTransforms or compared against ANTs/SimpleITK physical-space warps: physical-space validation requires first composing the field with the reference affine.
- Parameters:
deformation (torch.Tensor) – Displacement field (N, 3, X, Y, Z) in voxel units.
- Return type:
torch.Tensor
- Returns:
Determinant (N, X-1, Y-1, Z-1).
- Raises:
DependencyError – When torch is unavailable.
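A NumPy sketch of a voxel-space Jacobian determinant follows. It assumes forward differences, which is consistent with the documented (N, X-1, Y-1, Z-1) output shape but is not confirmed by the source; `jacobian_determinant_np` and `count_folds_np` are hypothetical names, and the real functions operate on torch tensors.

```python
import numpy as np


def jacobian_determinant_np(deformation):
    """det of d(x + u)/dx from forward differences; (N, 3, X, Y, Z) -> (N, X-1, Y-1, Z-1)."""
    u = np.asarray(deformation, dtype=np.float64)
    base = u[:, :, :-1, :-1, :-1]
    # Forward differences along each spatial axis, cropped to a common grid.
    dx = u[:, :, 1:, :-1, :-1] - base
    dy = u[:, :, :-1, 1:, :-1] - base
    dz = u[:, :, :-1, :-1, 1:] - base
    # grads[n, i, x, y, z, j] = d u_i / d x_j
    grads = np.stack([dx, dy, dz], axis=-1)
    # Move the component axis next to the derivative axis and add the identity.
    jac = np.moveaxis(grads, 1, -2) + np.eye(3)
    return np.linalg.det(jac)


def count_folds_np(deformation):
    """Number of voxels whose determinant is non-positive (a fold)."""
    return int((jacobian_determinant_np(deformation) <= 0).sum())


identity = np.zeros((2, 3, 4, 5, 6))
print(jacobian_determinant_np(identity).shape)
```

A zero displacement gives determinant 1 everywhere; a displacement that mirrors one axis gives a negative determinant at every voxel, which is what the fold counter flags.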
- thesis.workflows.learned_atlas.model.count_negative_jacobians(deformation)[source]#
Return the number of voxels with non-positive (folding) Jacobian det (voxel-space).
- Parameters:
deformation (torch.Tensor)
- Return type:
int
TDI-appropriate losses for learned-atlas training.
Tract-density (TDI / fdt_paths) volumes are heavy-tailed and sparse, so this
module provides:
- LocalNCCLogLoss - windowed local NCC on log(x + offset) density, accumulated in float32 with clamped variances for numerical stability.
- PresenceWeightedSoftDice - soft Dice weighted toward the subject’s TRUE support (not a union support), guarding against do-nothing models.
- GradientSmoothness - diffusion regulariser on the velocity field.
- LearnedAtlasLoss - configurable weighted composite. (No unbounded template-sharpness term: sharpness is left to the data term + the warp.)
torch is imported lazily; these losses are only used inside the Function body.
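The log-domain NCC idea can be sketched in NumPy. This is a deliberate simplification: it uses a single global window instead of sliding local windows, and `log_ncc_loss`, its squared-correlation form, and the `var_floor` parameter are assumptions for illustration rather than the behaviour of LocalNCCLogLoss.

```python
import numpy as np


def log_ncc_loss(pred, target, offset=1e-3, var_floor=1e-8):
    """1 - squared NCC of log(x + offset), computed over one global window."""
    # The log transform compresses the heavy tail of tract-density values.
    p = np.log(np.asarray(pred, dtype=np.float32) + np.float32(offset))
    t = np.log(np.asarray(target, dtype=np.float32) + np.float32(offset))
    p = p - p.mean()
    t = t - t.mean()
    cov = float((p * t).mean())
    # Clamp the variances so flat (all-background) windows cannot divide by ~0.
    var_p = max(float((p * p).mean()), var_floor)
    var_t = max(float((t * t).mean()), var_floor)
    return 1.0 - (cov * cov) / (var_p * var_t)


density = np.arange(1, 65, dtype=np.float64).reshape(4, 4, 4)
print(log_ncc_loss(density, density))
```

With the variance clamp, a constant prediction scores the worst value (1.0) instead of producing NaN, which matters for sparse volumes where many windows contain only background.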
Learned conditional deformable atlas - training entry + atlas emission.
train_learned_atlas() is called by the learned_atlas Nipype Function
node. It loads the cohort tract-density stack (reusing the averaging-atlas IO
and discovery so subject IDs and stack order stay aligned), trains a learnable
template T plus a deformation-only diffeomorphic network, then emits the
five atlas maps byte-compatibly with the averaging atlas.
Byte-compatibility: the output affine/header are taken from the first cohort
fdt image exactly as thesis.workflows.atlas.compute.generate_statistical_atlas()
does (NOT from a pre-existing atlas file), and mean is the learned template
while std/std_error/cov/prob_threshold come from
thesis.workflows.atlas._statistics.compute_atlas_statistics() over the
warped-template population - identical keys, dtypes (float32) and formulae.
Gotchas: all heavy imports are LOCAL (the body is pickled into a Nipype
subprocess and loguru cannot be pickled); diagnostics use
print(..., file=sys.stderr); torch is the optional ml extra; VRAM is
released before returning.
- thesis.workflows.learned_atlas.train.train_learned_atlas(input_dir, atlas_dir, tractography_relpath, normalization_method, training_space, affine_native_relpath, emit_baseline_maps, presence_value, cov_mean_threshold_pct, save_fields, verify_jacobian, train_config)[source]#
Train the learned deformable atlas and emit atlas NIfTIs.
- Parameters:
input_dir (str) – Cohort root containing numeric patient subdirectories.
atlas_dir (str) – Directory where atlas NIfTI files and fields/ are saved.
tractography_relpath (str) – Relative path under each subject to the template-space tractography output.
normalization_method (str) – NormalizationMethod value string.
training_space (str) – "template_native" or "affine_native".
affine_native_relpath (Optional[str]) – Per-subject affine-native map path; required when training_space='affine_native'.
emit_baseline_maps (list) – Subset of ATLAS_STATISTIC_NAMES to write (always includes 'mean').
presence_value (float) – Threshold for the derived prob_threshold map.
cov_mean_threshold_pct (float) – Fraction of the mean max for the cov map.
save_fields (bool) – Whether to write per-subject voxel-displacement fields.
verify_jacobian (bool) – Whether to count non-positive Jacobian voxels per subject.
train_config (dict) – Device/dtype/model/loss/optimizer scalars.
- Return type:
- Returns:
Absolute paths to generated files (atlas maps in ATLAS_STATISTIC_NAMES order, then per-subject fields, then the metadata JSON).
- Raises:
DependencyError – If torch is not installed.
ProcessingError – On empty cohort, shape mismatch, or training failure.
- thesis.workflows.learned_atlas.train.jacobian_min_and_folds(deformation)[source]#
Return (min determinant, fold count) for a voxel-space deformation field.
Local-import wrapper around the model’s voxel-space Jacobian so it pickles into the Nipype subprocess. Operates on a torch tensor.
Phase-0 diagnostics and non-circular evaluation metrics for the learned atlas.
Framework-free: every function operates on in-memory numpy arrays (typically the
(n_subjects, X, Y, Z) stack from atlas/_io._build_patient_stack) and
returns plain floats / arrays / JSON-serialisable dicts. No torch, Nipype, or
workflow-orchestration dependencies, mirroring
thesis.workflows.atlas.qc_metrics and
thesis.workflows.tract_similarity._metrics.
Two families:
- Phase-0 diagnostics answer “is a learned deformable atlas worth training on this cohort?” before any GPU time is spent.
- Non-circular evaluation metrics score warp(T, field_i) against the held-out subject on its TRUE support (not the union support that flatters do-nothing models), plus a regression-to-mean control.
- thesis.workflows.learned_atlas.diagnostics.residual_variance_map(value_maps)[source]#
Per-voxel residual variance of the cohort against its mean.
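A minimal NumPy sketch of the per-voxel residual variance, assuming a population (ddof=0) variance over the subject axis; the normalisation used by the real function is not stated, and `residual_variance_map_np` is a hypothetical name.

```python
import numpy as np


def residual_variance_map_np(value_maps):
    """Per-voxel variance of an (n_subjects, X, Y, Z) stack about its cohort mean."""
    stack = np.asarray(value_maps, dtype=np.float64)
    residual = stack - stack.mean(axis=0, keepdims=True)
    return (residual ** 2).mean(axis=0)


cohort = np.arange(2 * 3 * 3 * 3, dtype=np.float64).reshape(2, 3, 3, 3)
variance_map = residual_variance_map_np(cohort)
```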
- thesis.workflows.learned_atlas.diagnostics.centroid_scatter(value_maps, voxel_size_mm=(1.0, 1.0, 1.0), eps=1e-12)[source]#
Quantify how far per-subject density centroids scatter about their mean.
Large scatter is direct evidence of spatial (not intensity) disagreement - what a learned deformation field is meant to remove.
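Centroid scatter can be sketched as the mean distance of each subject's density-weighted centre of mass from the cohort's mean centroid. The return keys below (`centroids_mm`, `mean_scatter_mm`) and the choice of mean Euclidean distance are assumptions; the real function's output structure is not documented here.

```python
import numpy as np


def centroid_scatter_np(value_maps, voxel_size_mm=(1.0, 1.0, 1.0), eps=1e-12):
    """Density-weighted centroids (mm) and their mean distance from the cohort centroid."""
    stack = np.asarray(value_maps, dtype=np.float64)
    n = stack.shape[0]
    flat = stack.reshape(n, -1)
    # Voxel coordinates in the same C-order as the flattened volumes, scaled to mm.
    coords = np.indices(stack.shape[1:]).reshape(3, -1).T * np.asarray(voxel_size_mm)
    mass = flat.sum(axis=1, keepdims=True)
    centroids = (flat @ coords) / (mass + eps)
    distances = np.linalg.norm(centroids - centroids.mean(axis=0), axis=1)
    return {"centroids_mm": centroids, "mean_scatter_mm": float(distances.mean())}


# Two subjects whose mass sits two voxels apart along x, with 2 mm voxels along x.
toy = np.zeros((2, 5, 3, 3))
toy[0, 1, 1, 1] = 1.0
toy[1, 3, 1, 1] = 1.0
scatter = centroid_scatter_np(toy, voxel_size_mm=(2.0, 1.0, 1.0))
```

In the toy cohort the centroids lie at x = 2 mm and x = 6 mm, so each is 2 mm from their mean.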
- thesis.workflows.learned_atlas.diagnostics.occupancy_entropy(subject_masks, threshold=0.0, restrict_to_support=True, eps=1e-12)[source]#
Mean per-voxel binary occupancy entropy across a cohort.
Accepts a DENSITY stack and binarises internally at threshold (voxels strictly above it count as occupied), so the >0 binarisation is explicit rather than the silent bool-coercion of build_occupancy_map.
Entropy ~0 means voxels are consistently on/off (a sharp template is recoverable); ~1 means subjects flip a coin per voxel.
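The occupancy entropy can be sketched as the binary Shannon entropy (in bits) of the per-voxel occupancy probability. Defining the support as voxels occupied by at least one subject is an assumption, and `occupancy_entropy_np` is a hypothetical stand-in for the real function.

```python
import numpy as np


def occupancy_entropy_np(density_stack, threshold=0.0, restrict_to_support=True, eps=1e-12):
    """Mean per-voxel binary entropy (bits) of occupancy across subjects."""
    occupied = np.asarray(density_stack) > threshold  # explicit binarisation
    p = occupied.mean(axis=0)                         # fraction of subjects occupying each voxel
    p = np.clip(p, eps, 1.0 - eps)                    # avoid log2(0)
    entropy = -(p * np.log2(p) + (1.0 - p) * np.log2(1.0 - p))
    if restrict_to_support:
        support = occupied.any(axis=0)                # assumed definition of the support
        if not support.any():
            return float("nan")
        entropy = entropy[support]
    return float(entropy.mean())


consistent = np.ones((4, 3, 3, 3))
coin_flip = np.stack([np.ones((3, 3, 3)), np.zeros((3, 3, 3))])
print(occupancy_entropy_np(consistent), occupancy_entropy_np(coin_flip))
```

Restricting to the support keeps the large always-empty background from diluting the mean toward zero.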
- thesis.workflows.learned_atlas.diagnostics.mass_matched_residual_fraction(value_maps, eps=1e-12)[source]#
Cheap PROXY for the spatially-driven fraction of cohort residual.
Rescales every subject to the cohort-mean total mass and re-measures residual variance; the surviving fraction is a heuristic for disagreement a deformation could move. NOTE: no alignment is performed, so this is a proxy, not a true lower bound - it is advisory in the go/no-go decision.
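The mass-matching proxy can be sketched as a ratio of summed per-voxel variances after and before rescaling each subject to the cohort-mean total mass. Summing the variance over voxels is an assumption about how the residual is aggregated, and `mass_matched_residual_fraction_np` is a hypothetical name.

```python
import numpy as np


def mass_matched_residual_fraction_np(value_maps, eps=1e-12):
    """Fraction of residual variance that survives equalising each subject's total mass."""
    stack = np.asarray(value_maps, dtype=np.float64)
    flat = stack.reshape(stack.shape[0], -1)
    raw_residual = flat.var(axis=0).sum()
    mass = flat.sum(axis=1, keepdims=True)
    # Rescale every subject to the cohort-mean total mass (no spatial alignment).
    matched = flat * (mass.mean() / (mass + eps))
    matched_residual = matched.var(axis=0).sum()
    return float(matched_residual / (raw_residual + eps))


base = np.arange(1, 28, dtype=np.float64).reshape(3, 3, 3)
intensity_only = np.stack([base, 2.0 * base, 3.0 * base])  # differ by global scale only
shifted = np.zeros((2, 3, 3, 3))
shifted[0, 0, 0, 0] = 1.0
shifted[1, 2, 2, 2] = 1.0                                  # same mass, different place
```

A cohort that differs only by a global intensity scale gives a fraction near 0, while equal-mass subjects in different locations give a fraction near 1.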
- thesis.workflows.learned_atlas.diagnostics.phase0_go_no_go(value_maps, voxel_size_mm=(1.0, 1.0, 1.0), subject_mask_threshold=0.0, core_occupancy_threshold=0.5, min_centroid_scatter_mm=1.0, max_occupancy_entropy_bits=0.9)[source]#
Roll the Phase-0 diagnostics into a single go/no-go decision.
The hard gate is: spatial residual (centroid scatter above min_centroid_scatter_mm) AND a coherent bundle (mean support entropy below max_occupancy_entropy_bits). The mass-matched residual fraction is reported as an ADVISORY proxy (not part of the conjunction) until validated.
- Return type:
- Returns:
JSON-serialisable dict with go (bool), reasons, sub-criteria, raw diagnostics, and core_voxels.
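The hard gate reduces to a conjunction of two threshold tests. The sketch below takes precomputed diagnostics as inputs; `go_no_go_gate` and its dict keys other than `go` and `reasons` are hypothetical, and the advisory mass-matched fraction is deliberately left out of the decision, as described above.

```python
def go_no_go_gate(centroid_scatter_mm, occupancy_entropy_bits,
                  min_centroid_scatter_mm=1.0, max_occupancy_entropy_bits=0.9):
    """Combine two Phase-0 diagnostics into a go/no-go decision with reasons."""
    spatial_residual = centroid_scatter_mm > min_centroid_scatter_mm
    coherent_bundle = occupancy_entropy_bits < max_occupancy_entropy_bits
    reasons = []
    if not spatial_residual:
        reasons.append("centroid scatter too small: little spatial disagreement to remove")
    if not coherent_bundle:
        reasons.append("occupancy entropy too high: no coherent bundle to sharpen")
    return {
        "go": spatial_residual and coherent_bundle,
        "reasons": reasons,
        "spatial_residual": spatial_residual,
        "coherent_bundle": coherent_bundle,
    }


decision = go_no_go_gate(centroid_scatter_mm=2.5, occupancy_entropy_bits=0.4)
```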
- thesis.workflows.learned_atlas.diagnostics.true_support_mask(truth, threshold=0.0)[source]#
Return the support of the ground-truth subject volume (voxels > threshold).
- thesis.workflows.learned_atlas.diagnostics.delta_on_true_support(prediction, truth, loo_mean, support_threshold=0.0, eps=1e-12)[source]#
Improvement of the prediction over the leave-one-out (LOO) mean on the true support.
Headline non-circular metric. Restricted to voxels where the held-out subject has real signal, does warp(T, field_i) reduce error vs the LOO averaging baseline?
- Return type:
- Returns:
Dict with mae_pred, mae_loo_mean, delta_mae (positive = improvement), relative_reduction (delta / mae_loo_mean; range (-inf, 1]), support_voxels. NaN when the support is empty.
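A NumPy sketch of the metric, using the documented return keys. Placing eps in the denominator of the relative reduction is an assumption about where the real function applies it, and `delta_on_true_support_np` is a hypothetical name.

```python
import numpy as np


def delta_on_true_support_np(prediction, truth, loo_mean, support_threshold=0.0, eps=1e-12):
    """MAE improvement of the prediction over the LOO mean, on voxels where truth has signal."""
    prediction, truth, loo_mean = (np.asarray(a, dtype=np.float64)
                                   for a in (prediction, truth, loo_mean))
    support = truth > support_threshold
    n_support = int(support.sum())
    if n_support == 0:
        nan = float("nan")
        return {"mae_pred": nan, "mae_loo_mean": nan, "delta_mae": nan,
                "relative_reduction": nan, "support_voxels": 0}
    mae_pred = float(np.abs(prediction - truth)[support].mean())
    mae_loo = float(np.abs(loo_mean - truth)[support].mean())
    delta = mae_loo - mae_pred  # positive means the prediction beats the baseline
    return {"mae_pred": mae_pred, "mae_loo_mean": mae_loo, "delta_mae": delta,
            "relative_reduction": delta / (mae_loo + eps), "support_voxels": n_support}


truth = np.array([0.0, 1.0, 2.0, 3.0])
loo_mean = np.ones(4)
perfect = delta_on_true_support_np(truth, truth, loo_mean)
```

Restricting to the true support keeps the score from being inflated by the large background where prediction and baseline agree trivially.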
- thesis.workflows.learned_atlas.diagnostics.regression_to_mean_control(prediction, truth, loo_mean, support_threshold=0.0, eps=1e-12)[source]#
Detect a model that merely reproduces the cohort mean.
Correlates predicted departure (prediction - loo_mean) against real departure (truth - loo_mean) on the true support. A do-nothing model (prediction ~= loo_mean) yields near-zero correlation / energy ratio.
- Return type:
- Returns:
Dict with departure_pearson (in [-1, 1]; NaN if constant), predicted_departure_energy, real_departure_energy, departure_energy_ratio (near 0 flags a do-nothing model).
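The control can be sketched with the documented return keys. Defining energy as the sum of squared departures is an assumption, as is the eps placement; `regression_to_mean_control_np` is a hypothetical name.

```python
import numpy as np


def regression_to_mean_control_np(prediction, truth, loo_mean,
                                  support_threshold=0.0, eps=1e-12):
    """Compare predicted and real departures from the LOO mean on the true support."""
    prediction, truth, loo_mean = (np.asarray(a, dtype=np.float64)
                                   for a in (prediction, truth, loo_mean))
    support = truth > support_threshold
    predicted_departure = (prediction - loo_mean)[support]
    real_departure = (truth - loo_mean)[support]
    predicted_energy = float((predicted_departure ** 2).sum())
    real_energy = float((real_departure ** 2).sum())
    # Pearson correlation is undefined when either departure is constant.
    if (predicted_departure.size < 2 or predicted_departure.std() == 0
            or real_departure.std() == 0):
        pearson = float("nan")
    else:
        pearson = float(np.corrcoef(predicted_departure, real_departure)[0, 1])
    return {"departure_pearson": pearson,
            "predicted_departure_energy": predicted_energy,
            "real_departure_energy": real_energy,
            "departure_energy_ratio": predicted_energy / (real_energy + eps)}


truth = np.array([0.0, 1.0, 2.0, 3.0])
loo_mean = np.ones(4)
do_nothing = regression_to_mean_control_np(loo_mean, truth, loo_mean)
```

A model that returns the LOO mean has zero predicted departure, so its energy ratio is 0 and its correlation is undefined, which is the signature this control is designed to catch.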