learned_atlas: learned deformable tract-density template

Schema: LearnedAtlasConfig in src/thesis/workflows/learned_atlas/_params.py. Unlike the core sections, this namespace is owned by the learned_atlas workflow (registered via @workflow(config_namespace="learned_atlas", config_schema=LearnedAtlasConfig)), so a learned_atlas: block is validated against this model only when the workflow is on the registry path. It is reachable from the workflow body via getattr(config, "learned_atlas").
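Namespace ownership can be pictured with a toy registry. This is a hypothetical sketch, not the project's @workflow decorator. The names WORKFLOW_SCHEMAS, workflow, ToyLearnedAtlasConfig and validate_namespaces are made up for illustration, and a stdlib dataclass stands in for the real Pydantic schema. The sketch also assumes that blocks whose namespace has no registered schema are passed through unvalidated.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Callable, Dict

# Hypothetical registry: config namespace -> schema class, filled in by the decorator.
WORKFLOW_SCHEMAS: Dict[str, type] = {}


def workflow(config_namespace: str, config_schema: type) -> Callable:
    """Toy stand-in for @workflow: register the schema that owns a config namespace."""
    def decorator(fn: Callable) -> Callable:
        WORKFLOW_SCHEMAS[config_namespace] = config_schema
        return fn
    return decorator


@dataclass
class ToyLearnedAtlasConfig:
    enabled: bool = True
    min_subjects: int = 5


@workflow(config_namespace="learned_atlas", config_schema=ToyLearnedAtlasConfig)
def learned_atlas(config: Any) -> int:
    # The workflow body reaches its own namespace by attribute name.
    return getattr(config, "learned_atlas").min_subjects


def validate_namespaces(raw: Dict[str, Dict[str, Any]]) -> SimpleNamespace:
    """Validate blocks whose namespace is registered; pass the others through (assumption)."""
    parsed: Dict[str, Any] = {}
    for name, block in raw.items():
        schema = WORKFLOW_SCHEMAS.get(name)
        parsed[name] = schema(**block) if schema is not None else block
    return SimpleNamespace(**parsed)
```

An unknown key inside a registered block fails at construction time, here as a TypeError from the dataclass. In the real schema it surfaces as a validation error.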

The learned_atlas workflow is cohort-level: -p/--patient-id and --all are ignored, and it instead scans the configured output directory for per-subject tractography. It trains a deformable (VoxelMorph/AtlasMorph) tract-density template and emits the same five maps as the averaging atlas (atlas_mean.nii.gz, atlas_std.nii.gz, atlas_std_error.nii.gz, atlas_cov.nii.gz, atlas_prob_threshold.nii.gz). The existing atlas_to_patient and tract_similarity paths therefore keep working unchanged.

Training requires PyTorch, which ships in the optional ml extra (pip install -e ".[ml]"). torch is lazy-imported only inside the training node, so importing the package — and validating this config — never needs the extra.
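The lazy-import pattern looks roughly like this. It is a minimal sketch: require_optional and train_learned_atlas are hypothetical names, and only the idea of importing torch inside the training node comes from this page. Importing the module that defines these functions never touches torch. Torch is loaded, or a clear install hint is raised, only when the training function actually runs.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str = "ml") -> ModuleType:
    """Import an optional dependency on first use, with an install hint on failure."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name!r} is required for this step; "
            f'install the optional extra with: pip install -e ".[{extra}]"'
        ) from exc


def train_learned_atlas(config: dict) -> str:
    """Hypothetical training node: torch is imported here, never at module import time."""
    torch = require_optional("torch")
    device = torch.device(config.get("device", "cpu"))
    return str(device)
```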

thesis run -w learned_atlas -c default --protocol hcp

Top-level fields

Each field is listed as name (type; default; constraints, where any), followed by its description.

  • enabled (bool; default True): Master switch for learned atlas training.

  • training_space ("template_native" | "affine_native"; default "template_native"; one of the two literals): template_native trains on the template-registered warped volumes, so its output is byte-compatible with, and a drop-in replacement for, the averaging atlas in atlas_to_patient. affine_native trains on affine-only-aligned native maps, which avoids the double-registration confound. However, the resulting template cannot be consumed by template-space atlas_to_patient jobs without an extra warp.

  • device (str; default "cuda"; must match cpu|cuda(:\d+)?): Torch device for training, e.g. cpu, cuda or cuda:0.

  • dtype (str; default "float32"; one of float16, float32, bfloat16): Training / autocast dtype.

  • affine_native_relpath (str | None; default None; required when training_space="affine_native"): Path, relative to each patient directory, of the affine-only-aligned native map. It may point to a per-tract/bundle map and degrades gracefully to the whole-map case.

  • emit_baseline_maps (List[str]; default: the full atlas-map set mean, std, std_error, cov, prob_threshold; must be a subset of the atlas maps and include mean): Averaging-atlas maps to emit from the learned template. The full-set default makes the output directory a true drop-in for any atlas_to_patient job.

  • output_subdir (str; default "learned_atlas"; must match the @produces CohortDir literal): Cohort output subdirectory for the learned template (atlas_mean.nii.gz and the other maps) and the per-subject fields. It coexists with the averaging atlas directory.

  • min_subjects (int; default 5; ≥ 1): Minimum cohort size required to train.

  • save_fields (bool; default True): Write per-subject deformation fields alongside the maps.

  • verify_jacobian (bool; default True): Verify that every predicted field has det(Jacobian) > 0 (no folding).

  • model (LearnedAtlasModelConfig; default: see below): Network / template architecture knobs.

  • loss (LearnedAtlasLossConfig; default: see below): Loss-term weights.

  • optimizer (LearnedAtlasOptimizerConfig; default: see below): Optimizer / training-loop knobs.

The mean map (atlas_mean.nii.gz) is the learned template consumed by atlas_to_patient, which is why emit_baseline_maps must always include it.
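The scalar top-level constraints can be checked with a short stdlib routine. This is not the project's Pydantic model: check_top_level is a hypothetical name, and the ValueError it raises stands in for whatever the real field validators report. The sketch uses re.fullmatch on the assumption that "must match" means the whole device string must match the pattern.

```python
import re
from typing import Any, Dict

DEVICE_PATTERN = re.compile(r"cpu|cuda(:\d+)?")
ALLOWED_DTYPES = {"float16", "float32", "bfloat16"}
TRAINING_SPACES = {"template_native", "affine_native"}


def check_top_level(block: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the documented defaults, then enforce the documented scalar constraints."""
    cfg = {"enabled": True, "training_space": "template_native", "device": "cuda",
           "dtype": "float32", "min_subjects": 5, **block}
    if cfg["training_space"] not in TRAINING_SPACES:
        raise ValueError(f"training_space must be one of {sorted(TRAINING_SPACES)}")
    # fullmatch: the whole string must match, so "cuda0" or "gpu" are rejected.
    if not DEVICE_PATTERN.fullmatch(cfg["device"]):
        raise ValueError(f"device {cfg['device']!r} must match cpu|cuda(:N)")
    if cfg["dtype"] not in ALLOWED_DTYPES:
        raise ValueError(f"dtype must be one of {sorted(ALLOWED_DTYPES)}")
    if cfg["min_subjects"] < 1:
        raise ValueError("min_subjects must be >= 1")
    return cfg
```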

model: LearnedAtlasModelConfig

  • channels (int; default 1; ≥ 1): Input channels (1 for a single whole-map or per-bundle TDI volume).

  • n_templates (int; default 1; ≥ 1): Number of learnable base templates K. K = 1 gives the single sharp template. K > 1 selects the Phase-2 mixture-of-K stub, which is not implemented yet: the training body still trains K = 1 and logs a NotImplemented warning.

  • int_steps (int; default 7; between 0 and 10 inclusive): Number of scaling-and-squaring integration steps for the diffeomorphic warp. More steps give stronger invertibility (det(Jac) > 0) but use more memory.

  • enc_features (List[int]; default [16, 32, 32, 32]; non-empty, each ≥ 1): Registration U-Net encoder channel widths (coarse-to-fine).

  • dec_features (List[int]; default [32, 32, 32, 16, 16]; non-empty, each ≥ 1): Registration U-Net decoder channel widths.
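int_steps and verify_jacobian can be illustrated in one dimension. The sketch below is illustrative only: it integrates a stationary velocity field by scaling and squaring with linear interpolation, then applies a forward-difference Jacobian check. It is not the workflow's 3D warp layer, and the function names are hypothetical. In this sketch, int_steps = 0 simply returns the velocity as the displacement.

```python
import math
from typing import List, Sequence


def sample_linear(field: Sequence[float], pos: float) -> float:
    """Linearly interpolate a 1D field at a real-valued position, clamped to the grid."""
    n = len(field)
    pos = min(max(pos, 0.0), n - 1.0)
    i0 = int(math.floor(pos))
    i1 = min(i0 + 1, n - 1)
    w = pos - i0
    return (1.0 - w) * field[i0] + w * field[i1]


def integrate_velocity(velocity: Sequence[float], int_steps: int) -> List[float]:
    """Scaling and squaring: displacement u of phi = exp(v), with phi(x) = x + u(x)."""
    # Scale: the small step v / 2**T is close to the identity and safe to compose.
    disp = [v / 2 ** int_steps for v in velocity]
    # Square T times: phi <- phi o phi, i.e. u(x) <- u(x) + u(x + u(x)).
    for _ in range(int_steps):
        disp = [u + sample_linear(disp, x + u) for x, u in enumerate(disp)]
    return disp


def jacobian_1d(disp: Sequence[float]) -> List[float]:
    """Forward-difference Jacobian of phi(x) = x + u(x); all values > 0 means no folding."""
    return [1.0 + disp[i + 1] - disp[i] for i in range(len(disp) - 1)]
```

For a smooth velocity, all Jacobian values from integrate_velocity stay positive. This is the property verify_jacobian checks on the real 3D fields.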

loss: LearnedAtlasLossConfig

  • similarity_weight (float; default 1.0; ≥ 0.0): Weight on the log-domain local-NCC reconstruction term.

  • presence_weight (float; default 0.5; ≥ 0.0): Weight on the presence-weighted soft-Dice term computed on the subject's true support. It guards against the regression-to-mean failure mode, which correlation computed over the union support scores too favourably.

  • smoothness_weight (float; default 0.5; ≥ 0.0): Weight on the deformation-field gradient (smoothness) regulariser.

  • log_offset (float; default 1.0e-3; > 0.0): Additive offset applied before log-transforming the sparse TDI volumes for the NCC term.
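The sketch below shows how the weighted terms could combine. It is a simplified stdlib version on flattened 1D volumes, not the workflow's loss, with three substitutions. It uses a global NCC (Pearson correlation) instead of windowed local NCC. It uses a plain soft Dice between a template presence map and the subject's binary support, without reproducing the presence weighting. It uses a squared finite-difference penalty on a 1D displacement. learned_atlas_loss and the template_presence input are hypothetical.

```python
import math
from typing import Sequence


def global_ncc(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation of two flattened volumes (global stand-in for local NCC)."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b + 1e-12)


def soft_dice(prob: Sequence[float], mask: Sequence[float]) -> float:
    """Soft Dice between a [0, 1] probability map and a binary mask."""
    inter = sum(p * m for p, m in zip(prob, mask))
    return 2.0 * inter / (sum(prob) + sum(mask) + 1e-12)


def learned_atlas_loss(warped_template: Sequence[float], subject: Sequence[float],
                       template_presence: Sequence[float], displacement: Sequence[float],
                       similarity_weight: float = 1.0, presence_weight: float = 0.5,
                       smoothness_weight: float = 0.5, log_offset: float = 1.0e-3) -> float:
    # Log-transform the sparse TDI volumes so a few very dense voxels do not dominate NCC.
    log_t = [math.log(v + log_offset) for v in warped_template]
    log_s = [math.log(v + log_offset) for v in subject]
    similarity = 1.0 - global_ncc(log_t, log_s)
    # Subject's true support: voxels with any streamline density.
    support = [1.0 if v > 0 else 0.0 for v in subject]
    presence = 1.0 - soft_dice(template_presence, support)
    # Smoothness: mean squared finite difference of the displacement field.
    grads = [displacement[i + 1] - displacement[i] for i in range(len(displacement) - 1)]
    smoothness = sum(g * g for g in grads) / max(len(grads), 1)
    return (similarity_weight * similarity + presence_weight * presence
            + smoothness_weight * smoothness)
```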

optimizer: LearnedAtlasOptimizerConfig

  • lr (float; default 1.0e-4; > 0.0): Adam learning rate for the joint template + U-Net optimisation.

  • epochs (int; default 50; ≥ 1): Training epochs over the cohort.

  • batch_size (int; default 2; ≥ 1): Subjects per gradient step, bounded by GPU memory for ~145³ volumes.

  • seed (int; default 0; ≥ 0): RNG seed for reproducible template initialisation and shuffling.
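Seeded, reproducible shuffling into batches of batch_size could be done as below. This is a hypothetical helper (epoch_batches), not the workflow's data loader. Seeding a fresh RNG from (seed, epoch) is a design choice made here so that each epoch's order can be reproduced without replaying earlier epochs.

```python
import random
from typing import List, Sequence


def epoch_batches(subjects: Sequence[str], batch_size: int, seed: int,
                  epoch: int) -> List[List[str]]:
    """Shuffle subjects deterministically for one epoch and split them into batches."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    order = list(subjects)
    # random.Random hashes a str seed deterministically, so (seed, epoch) -> fixed order.
    random.Random(f"{seed}:{epoch}").shuffle(order)
    # The last batch may be smaller when the cohort size is not a multiple of batch_size.
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
```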

Example

learned_atlas:
  enabled: true
  training_space: template_native
  device: cuda:0
  dtype: float32
  output_subdir: learned_atlas
  min_subjects: 5
  save_fields: true
  verify_jacobian: true
  model:
    n_templates: 1
    int_steps: 7
  loss:
    similarity_weight: 1.0
    presence_weight: 0.5
    smoothness_weight: 0.5
  optimizer:
    lr: 1.0e-4
    epochs: 50
    batch_size: 2
    seed: 0

Training in affine-native space requires the relative path to the affine-only-aligned native maps:

learned_atlas:
  training_space: affine_native
  affine_native_relpath: tractography/probtrackx2/run/native_streamlines/fdt_paths.nii.gz
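The per-subject inputs for affine_native training could be gathered by resolving affine_native_relpath under each patient directory, as in the sketch below. It assumes one directory per patient directly under the output directory and skips patients whose map is missing. The real workflow's discovery and its per-tract fallback may differ. collect_affine_native_maps is an illustrative name.

```python
from pathlib import Path
from typing import List


def collect_affine_native_maps(output_dir: Path, relpath: str,
                               min_subjects: int = 5) -> List[Path]:
    """Return <patient>/<relpath> for every patient directory that has the map."""
    maps = sorted(
        patient / relpath
        for patient in output_dir.iterdir()
        # Cohort directories such as atlas/ or learned_atlas/ lack the map and are skipped.
        if patient.is_dir() and (patient / relpath).is_file()
    )
    if len(maps) < min_subjects:
        raise RuntimeError(
            f"found {len(maps)} subjects with {relpath}, need at least {min_subjects}"
        )
    return maps
```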

Notes

  • emit_baseline_maps is validated against ATLAS_STATISTIC_NAMES (the atlas map set): any unknown map name, or omitting mean, raises a ConfigurationError at load time.

  • A model_validator also enforces that affine_native training supplies affine_native_relpath.

  • The learned output directory (output_subdir, default learned_atlas) coexists with the averaging atlas directory, so both templates can be built into the same outputs/ tree without clobbering.

  • For the downstream consumers of the emitted maps, see atlas (averaging counterpart), transforms (atlas_to_patient), and tract_similarity.
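The two load-time checks described in the notes can be mirrored with a dataclass. This is a stdlib sketch, not the project's Pydantic model. ATLAS_MAPS copies the documented map set in place of ATLAS_STATISTIC_NAMES. ConfigurationError is defined locally as a hypothetical stand-in for the project's exception, and LearnedAtlasConfigSketch is an illustrative name.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Mirrors the documented map set; the real constant is ATLAS_STATISTIC_NAMES.
ATLAS_MAPS = ("mean", "std", "std_error", "cov", "prob_threshold")


class ConfigurationError(ValueError):
    """Hypothetical stand-in for the project's ConfigurationError."""


@dataclass
class LearnedAtlasConfigSketch:
    training_space: str = "template_native"
    affine_native_relpath: Optional[str] = None
    emit_baseline_maps: List[str] = field(default_factory=lambda: list(ATLAS_MAPS))

    def __post_init__(self) -> None:
        unknown = sorted(set(self.emit_baseline_maps) - set(ATLAS_MAPS))
        if unknown:
            raise ConfigurationError(f"unknown atlas maps: {unknown}")
        if "mean" not in self.emit_baseline_maps:
            raise ConfigurationError("emit_baseline_maps must include 'mean'")
        # Cross-field rule (a model_validator in the real schema).
        if self.training_space == "affine_native" and not self.affine_native_relpath:
            raise ConfigurationError("affine_native training requires affine_native_relpath")

    def output_files(self) -> List[str]:
        """File names written to output_subdir, e.g. atlas_mean.nii.gz."""
        return [f"atlas_{name}.nii.gz" for name in self.emit_baseline_maps]
```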