"""Config schema and constants for the learned_atlas cohort workflow.
The :class:`LearnedAtlasConfig` model is registered as the ``learned_atlas``
config namespace by the ``@workflow`` decorator (see ``workflow.py``). Any
``learned_atlas:`` block in a YAML config is validated against this schema at
config-load time and reachable from the workflow body via
``getattr(config, \"learned_atlas\")``.
The learned (deformable VoxelMorph/AtlasMorph) template remains drop-in
compatible with the averaging atlas: it emits the same
:data:`thesis.workflows.atlas._params.ATLAS_STATISTIC_NAMES` map set so the
existing ``atlas_to_patient`` and ``tract_similarity`` paths keep working.
This module depends only on pydantic + ``thesis.core.config.validators`` (and
the atlas params for the map-name source of truth); it never imports torch, so
importing the package is safe without the optional ``ml`` extra.
"""
from __future__ import annotations
from typing import List, Literal, Optional, Tuple
from pydantic import Field, field_validator, model_validator
from thesis.core.config.validators import BaseConfig
from thesis.workflows.atlas._params import ATLAS_STATISTIC_NAMES
# ---------------------------------------------------------------------------
# Tuning constants
# ---------------------------------------------------------------------------
#: Scaling-and-squaring steps (i.e. 2**7 sub-steps); the VoxelMorph default.
DEFAULT_INT_STEPS: int = 7

#: Offset added to sparse TDI volumes before the log transform used by the NCC term.
DEFAULT_LOG_OFFSET: float = 1e-3

#: Map names every atlas must emit. This aliases the averaging atlas's
#: constant directly instead of re-listing the names, so the learned atlas
#: cannot quietly drift away from the averaging atlas's map contract.
LEARNED_ATLAS_BASELINE_MAPS: Tuple[str, ...] = ATLAS_STATISTIC_NAMES
# ---------------------------------------------------------------------------
# Nested config sub-models
# ---------------------------------------------------------------------------
class LearnedAtlasModelConfig(BaseConfig):
    """Architecture settings for the learned template and its registration network.

    Attributes:
        channels: Number of input channels; 1 for a single whole-map or
            per-bundle TDI volume.
        n_templates: Count ``K`` of learnable base templates. ``1`` gives one
            sharp template. Any larger value selects the Phase-2 mixture-of-K
            stub; per the field description, training still uses K=1 and
            logs a NotImplemented warning.
        int_steps: Scaling-and-squaring steps used to integrate the velocity
            field into the diffeomorphic warp. More steps favour invertibility
            (det(Jacobian) > 0) at the cost of memory; this is encouraged, not
            proven. The schema accepts 0 through 10.
            NOTE(review): 0 presumably disables integration altogether.
            Confirm this against the network code, and use the top-level
            ``verify_jacobian`` switch for an actual folding check.
        enc_features: Channel widths of the registration U-Net encoder,
            listed coarse-to-fine.
        dec_features: Channel widths of the registration U-Net decoder.
    """

    channels: int = Field(
        ge=1,
        default=1,
        description="Input channels (1 for single whole-map/per-bundle TDI volume).",
    )
    n_templates: int = Field(
        ge=1,
        default=1,
        description=(
            "Number of learnable base templates K. "
            "1 = single sharp template; >1 = Phase-2 mixture-of-K stub "
            "(the training body trains K=1 and logs a NotImplemented warning)."
        ),
    )
    int_steps: int = Field(
        ge=0,
        le=10,
        default=DEFAULT_INT_STEPS,
        description=(
            "Scaling-and-squaring integration steps for the diffeomorphic "
            "warp. Higher = stronger invertibility (det(Jac) > 0), more memory."
        ),
    )
    enc_features: List[int] = Field(
        description="Registration U-Net encoder channel widths (coarse-to-fine).",
        default=[16, 32, 32, 32],
    )
    dec_features: List[int] = Field(
        description="Registration U-Net decoder channel widths.",
        default=[32, 32, 32, 16, 16],
    )

    @field_validator("enc_features", "dec_features")
    @classmethod
    def _check_feature_widths(cls, widths: List[int]) -> List[int]:
        """Reject an empty width list, or any width smaller than one channel."""
        if not widths:
            raise ValueError("feature width lists must contain at least one entry")
        # The list is known to be non-empty here, so min() is safe.
        if min(widths) < 1:
            raise ValueError("all feature widths must be >= 1")
        return widths
class LearnedAtlasLossConfig(BaseConfig):
    """Relative weights of the terms in the learned-atlas training objective.

    Attributes:
        similarity_weight: Scale of the local-NCC reconstruction term, which
            is computed in the log domain.
        presence_weight: Scale of the presence-weighted soft-Dice term,
            evaluated on the subject's true support. It counteracts the
            do-nothing / regression-to-mean solution that a correlation over
            the union support would otherwise reward.
        smoothness_weight: Scale of the gradient regulariser on the
            deformation field.
        log_offset: Strictly positive constant added to sparse TDI volumes
            before the log transform.
    """

    similarity_weight: float = Field(
        ge=0.0,
        default=1.0,
        description="Weight on the log-NCC reconstruction term.",
    )
    presence_weight: float = Field(
        ge=0.0,
        default=0.5,
        description=(
            "Weight on the presence-weighted soft-Dice term "
            "on the subject's true support (anti regression-to-mean)."
        ),
    )
    smoothness_weight: float = Field(
        ge=0.0,
        default=0.5,
        description="Weight on the deformation-field gradient (smoothness) regulariser.",
    )
    log_offset: float = Field(
        gt=0.0,
        default=DEFAULT_LOG_OFFSET,
        description="Additive offset before log-transforming sparse TDI volumes.",
    )
class LearnedAtlasOptimizerConfig(BaseConfig):
    """Optimiser and training-loop settings for learned-atlas training.

    Attributes:
        lr: Adam learning rate shared by the template and the U-Net, which
            are optimised jointly.
        epochs: Number of passes over the cohort.
        batch_size: Subjects per gradient step; in practice capped by GPU
            memory for volumes around 145^3.
        seed: Seed for the RNG driving template initialisation and shuffling,
            so that runs are reproducible.
    """

    lr: float = Field(
        gt=0.0,
        default=1.0e-4,
        description="Adam learning rate.",
    )
    epochs: int = Field(
        ge=1,
        default=50,
        description="Training epochs over the cohort.",
    )
    batch_size: int = Field(
        ge=1,
        default=2,
        description="Subjects per gradient step (GPU VRAM bound for ~145^3 volumes).",
    )
    seed: int = Field(
        ge=0,
        default=0,
        description="RNG seed for reproducible training.",
    )
# ---------------------------------------------------------------------------
# Top-level config namespace model
# ---------------------------------------------------------------------------
class LearnedAtlasConfig(BaseConfig):
    """Pydantic schema for the ``learned_atlas`` top-level config key.

    Registered with the framework by ``@workflow(config_namespace="learned_atlas",
    config_schema=LearnedAtlasConfig)``. Any ``learned_atlas:`` block in a YAML
    config is type-checked against this model and is reachable via
    ``getattr(config, "learned_atlas")`` from the workflow body.

    Attributes:
        enabled: Master switch for learned atlas training.
        training_space: ``template_native`` or ``affine_native``.
            ``template_native`` uses already template-registered warped
            volumes; it is the only space that is byte-compatible / drop-in
            for ``atlas_to_patient``. ``affine_native`` uses affine-only-aligned
            native maps. It avoids the double-registration confound, but its
            output is NOT consumable by the existing template-space
            ``atlas_to_patient`` jobs.
        device: Torch device string: ``'cpu'``, ``'cuda'``, or ``'cuda:N'``
            where ``N`` is an ASCII decimal index. Surrounding whitespace is
            stripped from the stored value.
        dtype: Training/autocast dtype name (``'float32'``, ``'float16'``,
            ``'bfloat16'``). Surrounding whitespace, then double quotes, then
            single quotes are stripped before validation.
        affine_native_relpath: Path of the affine-only-aligned native map,
            relative to each patient directory. When
            ``training_space='affine_native'`` it is required, must not be
            blank, and must not be absolute. It may be a per-tract/bundle
            path, and degrades gracefully to the whole-map case.
        emit_baseline_maps: Averaging-atlas maps to emit. Defaults to the FULL
            set, so the learned dir is a true drop-in for any atlas_to_patient
            job config. Must be a subset of ATLAS_STATISTIC_NAMES and must
            include 'mean'.
        output_subdir: Cohort output subdirectory for the learned template and
            per-subject fields. This is the single source of truth for the
            output location; the @produces CohortDir literal must equal this
            default.
        min_subjects: Minimum cohort size required to train.
        save_fields: Write per-subject deformation fields alongside the maps.
        verify_jacobian: Verify that every predicted field has
            det(Jacobian) > 0.
        model: Network / template architecture knobs.
        loss: Loss-term weights.
        optimizer: Optimizer / training-loop knobs.
    """

    enabled: bool = Field(default=True, description="Enable learned atlas training.")
    training_space: Literal["template_native", "affine_native"] = Field(
        default="template_native",
        description=(
            "Training space. 'template_native' (default) trains on the "
            "template-registered warped volumes and is byte-compatible / "
            "drop-in for atlas_to_patient. 'affine_native' trains on "
            "affine-only-aligned native maps (avoids the double-registration "
            "confound) but the resulting template is NOT consumable by the "
            "existing template-space atlas_to_patient jobs without an extra warp."
        ),
    )
    device: str = Field(
        default="cuda",
        description="Torch device for training ('cuda', 'cuda:0', 'cpu').",
    )
    dtype: str = Field(
        default="float32",
        description="Training/autocast dtype ('float32', 'float16', 'bfloat16').",
    )
    affine_native_relpath: Optional[str] = Field(
        default=None,
        description=(
            "Relative path under each patient dir to the affine-only-aligned "
            "native map; required when training_space='affine_native'. May be a "
            "per-tract/bundle path; degrades to the whole-map case."
        ),
    )
    emit_baseline_maps: List[str] = Field(
        default=list(ATLAS_STATISTIC_NAMES),
        description=(
            "Averaging-atlas maps to emit from the learned template. Defaults to "
            "the full ATLAS_STATISTIC_NAMES set so the learned dir is a true "
            "drop-in for any atlas_to_patient job. Must be a subset of "
            f"{list(ATLAS_STATISTIC_NAMES)} and include 'mean'."
        ),
    )
    output_subdir: str = Field(
        default="learned_atlas",
        description=(
            "Cohort output subdirectory for the learned template "
            "(atlas_mean.nii.gz et al.) and per-subject fields. Coexists with "
            "'atlas'. MUST match the @produces CohortDir literal."
        ),
    )
    min_subjects: int = Field(
        default=5,
        ge=1,
        description="Minimum cohort size required to train.",
    )
    save_fields: bool = Field(
        default=True,
        description="Write per-subject deformation fields.",
    )
    verify_jacobian: bool = Field(
        default=True,
        description="Verify every predicted field has det(Jacobian) > 0 (no folding).",
    )
    model: LearnedAtlasModelConfig = Field(
        default_factory=LearnedAtlasModelConfig,
        description="Network / template architecture knobs.",
    )
    loss: LearnedAtlasLossConfig = Field(
        default_factory=LearnedAtlasLossConfig,
        description="Loss-term weights.",
    )
    optimizer: LearnedAtlasOptimizerConfig = Field(
        default_factory=LearnedAtlasOptimizerConfig,
        description="Optimizer / training-loop knobs.",
    )

    @field_validator("device")
    @classmethod
    def _check_device(cls, v: str) -> str:
        """Validate the device string shape without importing torch.

        Returns:
            The whitespace-stripped device string.

        Raises:
            ValueError: If the string is not ``'cpu'``, ``'cuda'`` or
                ``'cuda:N'``.
        """
        import re

        device = v.strip()
        # Use an explicit ASCII class for the index. In a str pattern, the
        # digit escape also matches any Unicode decimal digit (e.g.
        # Arabic-Indic or full-width digits), which torch's device-string
        # parser is not expected to accept.
        if not re.fullmatch(r"cpu|cuda(?::[0-9]+)?", device):
            raise ValueError(f"device must be 'cpu', 'cuda', or 'cuda:N', got {v!r}")
        return device

    @field_validator("dtype")
    @classmethod
    def _check_dtype(cls, v: str) -> str:
        """Validate the torch dtype name without importing torch.

        Surrounding whitespace, then double quotes, then single quotes are
        stripped first, so doubly-quoted YAML values still normalise.
        """
        allowed = {"float16", "float32", "bfloat16"}
        normalized = v.strip().strip('"').strip("'")
        if normalized not in allowed:
            raise ValueError(f"dtype must be one of {sorted(allowed)}, got {v!r}")
        return normalized

    @model_validator(mode="after")
    def _validate_post(self) -> "LearnedAtlasConfig":
        """Run cross-field checks: emit-map subset and affine-native path validity.

        Raises:
            ValueError: If ``emit_baseline_maps`` names a map outside
                ``ATLAS_STATISTIC_NAMES`` or omits ``'mean'``; or if
                ``training_space='affine_native'`` and ``affine_native_relpath``
                is missing, blank, or an absolute path.
        """
        unknown = [m for m in self.emit_baseline_maps if m not in ATLAS_STATISTIC_NAMES]
        if unknown:
            raise ValueError(
                f"emit_baseline_maps contains unknown map(s) {unknown}; "
                f"valid options are {list(ATLAS_STATISTIC_NAMES)}."
            )
        if "mean" not in self.emit_baseline_maps:
            raise ValueError(
                "emit_baseline_maps must include 'mean' (atlas_mean.nii.gz is the "
                "learned template consumed by atlas_to_patient)."
            )
        if self.training_space == "affine_native":
            import os

            # A whitespace-only string is truthy, so testing the raw value
            # alone would let a blank path through.
            relpath = (self.affine_native_relpath or "").strip()
            if not relpath:
                raise ValueError(
                    "training_space='affine_native' requires affine_native_relpath "
                    "to point at the affine-only-aligned native map."
                )
            # The path is documented as relative to each patient directory.
            # An absolute path would resolve to the same location for every
            # subject instead of one map per patient.
            if os.path.isabs(relpath):
                raise ValueError(
                    "affine_native_relpath must be relative to each patient "
                    f"directory, got absolute path {self.affine_native_relpath!r}."
                )
        return self