learned_atlas — learned deformable tract-density template
Schema: LearnedAtlasConfig in src/thesis/workflows/learned_atlas/_params.py. Unlike the core sections, this namespace is owned by the learned_atlas workflow (registered via @workflow(config_namespace="learned_atlas", config_schema=LearnedAtlasConfig)), so a learned_atlas: block is validated against this model only when the workflow is on the registry path. It is reachable from the workflow body via getattr(config, "learned_atlas").
The learned_atlas workflow is cohort-level — -p/--patient-id and --all are ignored; it scans the configured output directory for per-subject tractography. It trains a deformable (VoxelMorph/AtlasMorph) tract-density template and emits the same five maps as the averaging atlas (atlas_mean.nii.gz, atlas_std.nii.gz, atlas_std_error.nii.gz, atlas_cov.nii.gz, atlas_prob_threshold.nii.gz), so the existing atlas_to_patient and tract_similarity paths keep working unchanged.
Training requires PyTorch, which ships in the optional ml extra (pip install -e ".[ml]"). torch is lazy-imported only inside the training node, so importing the package — and validating this config — never needs the extra.
    thesis run -w learned_atlas -c default --protocol hcp
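The lazy-import behaviour described above can be sketched as follows. This is a minimal illustration, not the package's actual code: `require_optional`, `validate_learned_atlas`, and `train_node` are hypothetical names, and only the idea (resolve torch inside the training step, never at import or validation time) comes from the text.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency on demand, with an actionable error."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"'{module_name}' is required for this step; "
            f'install it with: pip install -e ".[{extra}]"'
        ) from exc


def validate_learned_atlas(raw: dict) -> dict:
    """Hypothetical stand-in for config validation: touches no ML imports."""
    if not isinstance(raw.get("enabled", False), bool):
        raise ValueError("enabled must be a boolean")
    return raw


def train_node(raw: dict) -> None:
    """Hypothetical training node: torch is resolved only when it runs."""
    torch = require_optional("torch", "ml")
    print("training with torch", torch.__version__)
```

With this layout, validating a config works on an install without the ml extra, and the missing dependency only surfaces (with an install hint) when the training step is actually reached.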
Top-level fields
| Field | Type | Default | Constraints | Description |
|---|---|---|---|---|
| enabled | | | — | Master switch for learned atlas training. |
| training_space | | | one of the two literals | |
| device | | | must match | Torch device for training ( |
| dtype | | | one of | Training / autocast dtype. |
| affine_native_relpath | | | required when training_space is affine_native | Relative path under each patient directory to the affine-only-aligned native map. May be a per-tract/bundle path and degrades gracefully to the whole-map case. |
| emit_baseline_maps | | full atlas-map set ( | must be a subset of the atlas maps and include mean | Averaging-atlas maps to emit from the learned template. Defaults to the full set so the output dir is a true drop-in for any |
| output_subdir | | learned_atlas | must match the | Cohort output subdirectory for the learned template ( |
| min_subjects | | | | Minimum cohort size required to train. |
| save_fields | | | — | Write per-subject deformation fields alongside the maps. |
| verify_jacobian | | | — | Verify every predicted field has det(Jacobian) > 0 (no folding). |
| model | | (see below) | — | Network / template architecture knobs. |
| loss | | (see below) | — | Loss-term weights. |
| optimizer | | (see below) | — | Optimizer / training-loop knobs. |
The mean map (atlas_mean.nii.gz) is the learned template consumed by atlas_to_patient, which is why emit_baseline_maps must always include it.
model — LearnedAtlasModelConfig
| Field | Type | Default | Constraints | Description |
|---|---|---|---|---|
| | | | | Input channels (1 for a single whole-map / per-bundle TDI volume). |
| n_templates | | | | Number of learnable base templates |
| int_steps | | | | Scaling-and-squaring integration steps for the diffeomorphic warp. Higher = stronger invertibility (det(Jac) > 0), more memory. |
| | | | non-empty; each | Registration U-Net encoder channel widths (coarse-to-fine). |
| | | | non-empty; each | Registration U-Net decoder channel widths. |
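To make the role of int_steps and the det(Jacobian) > 0 check concrete, here is a deliberately simplified sketch of scaling and squaring in one dimension, written in pure Python. It is an illustration under stated assumptions, not the workflow's implementation: the real model operates on 3-D volumes in PyTorch, and the helper names (`interp`, `integrate_velocity`, `min_jacobian`, `bump_velocity`) are hypothetical. In 1-D the Jacobian determinant reduces to the derivative of the map x + u(x).

```python
import math


def interp(values, x):
    """Linearly interpolate `values` (sampled at 0..n-1) at x, clamped to the grid."""
    n = len(values)
    x = min(max(x, 0.0), n - 1.0)
    i = min(int(x), n - 2)
    t = x - i
    return (1.0 - t) * values[i] + t * values[i + 1]


def integrate_velocity(velocity, int_steps):
    """Scaling and squaring: displacement obtained from a stationary velocity field."""
    # Scale: start from the small displacement v / 2**int_steps.
    disp = [v / (2 ** int_steps) for v in velocity]
    # Square: compose the map with itself int_steps times,
    # u(x) <- u(x) + u(x + u(x)).
    for _ in range(int_steps):
        disp = [u + interp(disp, i + u) for i, u in enumerate(disp)]
    return disp


def min_jacobian(disp):
    """Smallest finite-difference derivative of the map x + u(x)."""
    return min(1.0 + disp[i + 1] - disp[i] for i in range(len(disp) - 1))


def bump_velocity(n, amplitude):
    """Illustrative smooth velocity field that vanishes at both ends of the grid."""
    return [amplitude * math.sin(math.pi * i / (n - 1)) for i in range(n)]
```

Applying `bump_velocity(32, 12.0)` directly as a displacement folds the grid (`min_jacobian` is negative), whereas `integrate_velocity(..., 7)` yields a map whose minimum derivative stays positive. Each squaring step doubles the integration time, which is why more steps cost more compute and memory.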
loss — LearnedAtlasLossConfig
| Field | Type | Default | Constraints | Description |
|---|---|---|---|---|
| similarity_weight | | | | Weight on the log-domain local-NCC reconstruction term. |
| presence_weight | | | | Weight on the presence-weighted soft-Dice term on the subject’s true support; guards against the regression-to-mean failure mode that union-support correlation flatters. |
| smoothness_weight | | | | Weight on the deformation-field gradient (smoothness) regulariser. |
| | | | | Additive offset before log-transforming sparse TDI volumes for the NCC term. |
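As a rough sketch of how the fields above might fit together, the snippet below shows the log transform with an additive offset and a weighted sum of three scalar loss terms. The function names and the plain weighted-sum form are assumptions for illustration; the individual terms (local NCC, presence-weighted soft Dice, gradient regulariser) are not implemented here, and the offset value is passed explicitly because its default is not stated in this section.

```python
import math


def log_transform(volume, offset):
    """Map sparse non-negative counts to the log domain; the offset keeps log(0) finite."""
    return [math.log(v + offset) for v in volume]


def total_loss(similarity, presence, smoothness, weights):
    """Weighted sum of three loss terms, each already reduced to a scalar."""
    return (
        weights["similarity_weight"] * similarity
        + weights["presence_weight"] * presence
        + weights["smoothness_weight"] * smoothness
    )
```

For example, `log_transform([0.0, 1.0], 1.0)` maps the empty voxel to 0.0 instead of failing on log(0), which is the reason an offset is needed for sparse TDI volumes.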
optimizer — LearnedAtlasOptimizerConfig
| Field | Type | Default | Constraints | Description |
|---|---|---|---|---|
| lr | | | | Adam learning rate for the joint template + U-Net optimisation. |
| epochs | | | | Training epochs over the cohort. |
| batch_size | | | | Subjects per gradient step (bounded by GPU VRAM for ~145³ volumes). |
| seed | | | | RNG seed for reproducible template init and shuffling. |
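The interaction of epochs, batch_size, and seed can be illustrated with a small, self-contained sketch. It is an assumption about how a seeded schedule could be built, not the workflow's actual data loader; `batch_schedule` is a hypothetical helper.

```python
import random


def batch_schedule(subject_ids, batch_size, epochs, seed):
    """Return, per epoch, the cohort split into shuffled batches (reproducible)."""
    rng = random.Random(seed)       # dedicated RNG, independent of global state
    order = sorted(subject_ids)     # fixed starting order regardless of input order
    schedule = []
    for _ in range(epochs):
        rng.shuffle(order)
        schedule.append(
            [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
        )
    return schedule
```

Two calls with the same seed produce identical schedules, and every subject appears exactly once per epoch; the last batch of an epoch may be smaller than batch_size.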
Example
    learned_atlas:
      enabled: true
      training_space: template_native
      device: cuda:0
      dtype: float32
      output_subdir: learned_atlas
      min_subjects: 5
      save_fields: true
      verify_jacobian: true
      model:
        n_templates: 1
        int_steps: 7
      loss:
        similarity_weight: 1.0
        presence_weight: 0.5
        smoothness_weight: 0.5
      optimizer:
        lr: 1.0e-4
        epochs: 50
        batch_size: 2
        seed: 0
Training in affine-native space requires the relative path to the affine-only-aligned native maps:
    learned_atlas:
      training_space: affine_native
      affine_native_relpath: tractography/probtrackx2/run/native_streamlines/fdt_paths.nii.gz
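Because the relative path is resolved under each patient directory and the workflow needs a minimum cohort size, the discovery step might look roughly like the sketch below. The directory layout (one subdirectory per patient under the output directory) and the helper name `find_training_maps` are assumptions made for illustration.

```python
from pathlib import Path


def find_training_maps(output_dir: Path, relpath: str, min_subjects: int) -> dict[str, Path]:
    """Collect per-subject maps at output_dir/<patient>/<relpath>."""
    maps = {}
    for patient_dir in sorted(p for p in output_dir.iterdir() if p.is_dir()):
        candidate = patient_dir / relpath
        # Directories without the map (e.g. cohort-level outputs) are skipped.
        if candidate.is_file():
            maps[patient_dir.name] = candidate
    if len(maps) < min_subjects:
        raise ValueError(
            f"found {len(maps)} subject map(s), need at least {min_subjects}"
        )
    return maps
```

Cohort-level directories such as the learned template's own output directory contain no file at the relative path, so they drop out of the scan without special handling.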
Notes
- emit_baseline_maps is validated against ATLAS_STATISTIC_NAMES (the atlas map set): any unknown map name, or omitting mean, raises a ConfigurationError at load time.
- A model_validator also enforces that affine_native training supplies affine_native_relpath.
- The learned output directory (output_subdir, default learned_atlas) coexists with the averaging atlas directory, so both templates can be built into the same outputs/ tree without clobbering.
- For the downstream consumers of the emitted maps, see atlas (averaging counterpart), transforms (atlas_to_patient), and tract_similarity.
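The two validation rules in the notes can be sketched with the standard library alone. This is a stand-in for illustration, not the real schema: the actual model lives in LearnedAtlasConfig, the contents of ATLAS_STATISTIC_NAMES shown here are inferred from the emitted file names, and the local ConfigurationError class and `LearnedAtlasSketch` name are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional

# Assumed contents, inferred from the atlas_*.nii.gz file names.
ATLAS_STATISTIC_NAMES = ("mean", "std", "std_error", "cov", "prob_threshold")


class ConfigurationError(ValueError):
    """Local stand-in for the package's configuration error type."""


@dataclass(frozen=True)
class LearnedAtlasSketch:
    training_space: str
    affine_native_relpath: Optional[str] = None
    emit_baseline_maps: tuple = ATLAS_STATISTIC_NAMES

    def __post_init__(self):
        # Rule 1: only known map names, and the mean map is mandatory.
        unknown = [m for m in self.emit_baseline_maps if m not in ATLAS_STATISTIC_NAMES]
        if unknown:
            raise ConfigurationError(f"unknown atlas maps: {unknown}")
        if "mean" not in self.emit_baseline_maps:
            raise ConfigurationError("emit_baseline_maps must include 'mean'")
        # Rule 2: affine-native training needs the relative path.
        if self.training_space == "affine_native" and not self.affine_native_relpath:
            raise ConfigurationError(
                "affine_native training requires affine_native_relpath"
            )
```

Running the checks at construction time mirrors the load-time behaviour described above: an invalid block fails before any training starts.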