# Source code for thesis.workflows.learned_atlas._params

"""Config schema and constants for the learned_atlas cohort workflow.

The :class:`LearnedAtlasConfig` model is registered as the ``learned_atlas``
config namespace by the ``@workflow`` decorator (see ``workflow.py``). Any
``learned_atlas:`` block in a YAML config is validated against this schema at
config-load time and reachable from the workflow body via
``getattr(config, "learned_atlas")``.

The learned (deformable VoxelMorph/AtlasMorph) template remains drop-in
compatible with the averaging atlas: it emits the same
:data:`thesis.workflows.atlas._params.ATLAS_STATISTIC_NAMES` map set so the
existing ``atlas_to_patient`` and ``tract_similarity`` paths keep working.

This module depends only on pydantic + ``thesis.core.config.validators`` (and
the atlas params for the map-name source of truth); it never imports torch, so
importing the package is safe without the optional ``ml`` extra.
"""

from __future__ import annotations

from typing import List, Literal, Optional, Tuple

from pydantic import Field, field_validator, model_validator

from thesis.core.config.validators import BaseConfig
from thesis.workflows.atlas._params import ATLAS_STATISTIC_NAMES

# ---------------------------------------------------------------------------
# Tuning constants
# ---------------------------------------------------------------------------

# Default for ``LearnedAtlasModelConfig.int_steps`` (that field accepts 0..10).
DEFAULT_INT_STEPS: int = 7
"""Scaling-and-squaring integration steps (2**7 sub-steps): VoxelMorph default."""

# Default for ``LearnedAtlasLossConfig.log_offset``; that field is declared
# with ``gt=0.0``, so any replacement value must stay strictly positive.
DEFAULT_LOG_OFFSET: float = 1.0e-3
"""Additive offset before log-transforming sparse TDI volumes for the NCC term."""

# Source of truth for the emitted-map contract; reused (not re-listed) so the
# learned atlas can never silently diverge from the averaging atlas.
# This binds the very same object as ATLAS_STATISTIC_NAMES (an alias, not a copy).
# NOTE(review): annotated as a tuple of str -- confirm ATLAS_STATISTIC_NAMES is
# actually declared as a tuple in thesis.workflows.atlas._params.
LEARNED_ATLAS_BASELINE_MAPS: Tuple[str, ...] = ATLAS_STATISTIC_NAMES


# ---------------------------------------------------------------------------
# Nested config sub-models
# ---------------------------------------------------------------------------


class LearnedAtlasModelConfig(BaseConfig):
    """Architecture settings for the learned-atlas network and template.

    Attributes:
        channels: Number of input channels; ``1`` corresponds to a single
            whole-map / per-bundle TDI volume. Must be at least 1.
        n_templates: Count ``K`` of learnable base templates. ``1`` is the
            single sharp template; values ``>1`` select the Phase-2
            mixture-of-K stub. Must be at least 1.
        int_steps: Scaling-and-squaring integration steps for the
            diffeomorphic velocity field (intended to keep
            det(Jacobian) > 0). Accepted range is 0 to 10 inclusive.
        enc_features: Channel widths of the registration U-Net encoder
            (coarse-to-fine). Must be non-empty with every width >= 1.
        dec_features: Channel widths of the registration U-Net decoder.
            Must be non-empty with every width >= 1.
    """

    channels: int = Field(
        default=1,
        ge=1,
        description="Input channels (1 for single whole-map/per-bundle TDI volume).",
    )
    n_templates: int = Field(
        default=1,
        ge=1,
        description=(
            "Number of learnable base templates K. 1 = single sharp template; "
            ">1 = Phase-2 mixture-of-K stub (the training body trains K=1 and "
            "logs a NotImplemented warning)."
        ),
    )
    int_steps: int = Field(
        default=DEFAULT_INT_STEPS,
        ge=0,
        le=10,
        description=(
            "Scaling-and-squaring integration steps for the diffeomorphic warp. "
            "Higher = stronger invertibility (det(Jac) > 0), more memory."
        ),
    )
    enc_features: List[int] = Field(
        default=[16, 32, 32, 32],
        description="Registration U-Net encoder channel widths (coarse-to-fine).",
    )
    dec_features: List[int] = Field(
        default=[32, 32, 32, 16, 16],
        description="Registration U-Net decoder channel widths.",
    )

    @field_validator("enc_features", "dec_features")
    @classmethod
    def _check_feature_widths(cls, v: List[int]) -> List[int]:
        """Reject empty width lists and any channel width below one.

        Args:
            v: The encoder or decoder width list being validated.

        Returns:
            The same list, unchanged, when it passes both checks.

        Raises:
            ValueError: If the list is empty or contains a width < 1.
        """
        if not v:
            raise ValueError("feature width lists must contain at least one entry")
        for width in v:
            if width < 1:
                raise ValueError("all feature widths must be >= 1")
        return v
class LearnedAtlasLossConfig(BaseConfig):
    """Weights of the individual loss terms used in learned-atlas training.

    All three weights must be non-negative; ``log_offset`` must be strictly
    positive.

    Attributes:
        similarity_weight: Weight on the log-domain local-NCC reconstruction
            term.
        presence_weight: Weight on the presence-weighted soft-Dice term on
            the true (subject) support; guards against the do-nothing /
            regression-to-mean failure mode that union-support correlation
            flatters.
        smoothness_weight: Weight on the deformation-field gradient
            regulariser.
        log_offset: Additive offset applied before the log transform of
            sparse TDI volumes.
    """

    similarity_weight: float = Field(
        default=1.0,
        ge=0.0,
        description="Weight on the log-NCC reconstruction term.",
    )
    presence_weight: float = Field(
        default=0.5,
        ge=0.0,
        description=(
            "Weight on the presence-weighted soft-Dice term on the subject's "
            "true support (anti regression-to-mean)."
        ),
    )
    smoothness_weight: float = Field(
        default=0.5,
        ge=0.0,
        description="Weight on the deformation-field gradient (smoothness) regulariser.",
    )
    log_offset: float = Field(
        default=DEFAULT_LOG_OFFSET,
        gt=0.0,
        description="Additive offset before log-transforming sparse TDI volumes.",
    )
class LearnedAtlasOptimizerConfig(BaseConfig):
    """Optimizer and training-loop settings for learned-atlas training.

    Attributes:
        lr: Adam learning rate for the joint template + U-Net optimisation;
            must be strictly positive.
        epochs: Number of training epochs over the cohort; at least 1.
        batch_size: Subjects per gradient step (bounded by GPU VRAM for
            ~145^3 volumes); at least 1.
        seed: Non-negative RNG seed for reproducible template init and
            shuffling.
    """

    lr: float = Field(
        default=1.0e-4,
        gt=0.0,
        description="Adam learning rate.",
    )
    epochs: int = Field(
        default=50,
        ge=1,
        description="Training epochs over the cohort.",
    )
    batch_size: int = Field(
        default=2,
        ge=1,
        description="Subjects per gradient step (GPU VRAM bound for ~145^3 volumes).",
    )
    seed: int = Field(
        default=0,
        ge=0,
        description="RNG seed for reproducible training.",
    )
# ---------------------------------------------------------------------------
# Top-level config namespace model
# ---------------------------------------------------------------------------
class LearnedAtlasConfig(BaseConfig):
    """Pydantic schema for the ``learned_atlas`` top-level config key.

    Registered with the framework by
    ``@workflow(config_namespace="learned_atlas",
    config_schema=LearnedAtlasConfig)``; any ``learned_atlas:`` block in a
    YAML config is type-checked against this model and reachable via
    ``getattr(config, "learned_atlas")`` from the workflow body.

    Attributes:
        enabled: Master switch for learned atlas training.
        training_space: ``template_native`` (already template-registered
            warped volumes; the only space that is byte-compatible / drop-in
            for ``atlas_to_patient``) or ``affine_native`` (affine-only-aligned
            native maps, avoids the double-registration confound but is NOT
            consumable by the existing template-space ``atlas_to_patient``
            jobs).
        device: Torch device string ('cuda', 'cuda:0', 'cpu'). Surrounding
            whitespace is stripped during validation.
        dtype: Training/autocast dtype name ('float32', 'float16',
            'bfloat16'). Surrounding whitespace and quote characters are
            stripped during validation.
        affine_native_relpath: Relative path under each patient directory to
            the affine-only-aligned native map; required (and must not be
            blank) when ``training_space='affine_native'``. May be a
            per-tract/bundle path and degrades gracefully to the whole-map
            case.
        emit_baseline_maps: Averaging-atlas maps to emit. Defaults to the
            FULL set so the learned dir is a true drop-in for any
            atlas_to_patient job config; must be a subset of
            ATLAS_STATISTIC_NAMES and include 'mean'.
        output_subdir: Cohort output subdirectory for the learned template
            and per-subject fields. Single source of truth for the output
            location; the @produces CohortDir literal must equal this
            default.
        min_subjects: Minimum cohort size required to train.
        save_fields: Write per-subject deformation fields alongside the maps.
        verify_jacobian: Verify every predicted field has det(Jacobian) > 0.
        model: Network / template architecture knobs.
        loss: Loss-term weights.
        optimizer: Optimizer / training-loop knobs.
    """

    enabled: bool = Field(default=True, description="Enable learned atlas training.")
    training_space: Literal["template_native", "affine_native"] = Field(
        default="template_native",
        description=(
            "Training space. 'template_native' (default) trains on the "
            "template-registered warped volumes and is byte-compatible / "
            "drop-in for atlas_to_patient. 'affine_native' trains on "
            "affine-only-aligned native maps (avoids the double-registration "
            "confound) but the resulting template is NOT consumable by the "
            "existing template-space atlas_to_patient jobs without an extra warp."
        ),
    )
    device: str = Field(
        default="cuda",
        description="Torch device for training ('cuda', 'cuda:0', 'cpu').",
    )
    dtype: str = Field(
        default="float32",
        description="Training/autocast dtype ('float32', 'float16', 'bfloat16').",
    )
    affine_native_relpath: Optional[str] = Field(
        default=None,
        description=(
            "Relative path under each patient dir to the affine-only-aligned "
            "native map; required when training_space='affine_native'. May be a "
            "per-tract/bundle path; degrades to the whole-map case."
        ),
    )
    emit_baseline_maps: List[str] = Field(
        default=list(ATLAS_STATISTIC_NAMES),
        description=(
            "Averaging-atlas maps to emit from the learned template. Defaults to "
            "the full ATLAS_STATISTIC_NAMES set so the learned dir is a true "
            "drop-in for any atlas_to_patient job. Must be a subset of "
            f"{list(ATLAS_STATISTIC_NAMES)} and include 'mean'."
        ),
    )
    output_subdir: str = Field(
        default="learned_atlas",
        description=(
            "Cohort output subdirectory for the learned template "
            "(atlas_mean.nii.gz et al.) and per-subject fields. Coexists with "
            "'atlas'. MUST match the @produces CohortDir literal."
        ),
    )
    min_subjects: int = Field(
        default=5,
        ge=1,
        description="Minimum cohort size required to train.",
    )
    save_fields: bool = Field(
        default=True,
        description="Write per-subject deformation fields.",
    )
    verify_jacobian: bool = Field(
        default=True,
        description="Verify every predicted field has det(Jacobian) > 0 (no folding).",
    )
    model: LearnedAtlasModelConfig = Field(
        default_factory=LearnedAtlasModelConfig,
        description="Network / template architecture knobs.",
    )
    loss: LearnedAtlasLossConfig = Field(
        default_factory=LearnedAtlasLossConfig,
        description="Loss-term weights.",
    )
    optimizer: LearnedAtlasOptimizerConfig = Field(
        default_factory=LearnedAtlasOptimizerConfig,
        description="Optimizer / training-loop knobs.",
    )

    @field_validator("device")
    @classmethod
    def _check_device(cls, v: str) -> str:
        """Validate the device string shape without importing torch.

        Args:
            v: Raw device string from the config.

        Returns:
            The device string with surrounding whitespace removed.

        Raises:
            ValueError: If the value is not 'cpu', 'cuda', or 'cuda:N'.
        """
        import re

        device = v.strip()
        # ``[0-9]`` rather than ``\d``: on str patterns ``\d`` also matches
        # non-ASCII Unicode decimal digits, which would let values such as
        # "cuda:٣" through validation.
        if not re.fullmatch(r"cpu|cuda(:[0-9]+)?", device):
            raise ValueError(f"device must be 'cpu', 'cuda', or 'cuda:N', got {v!r}")
        return device

    @field_validator("dtype")
    @classmethod
    def _check_dtype(cls, v: str) -> str:
        """Validate the torch dtype name without importing torch.

        Args:
            v: Raw dtype name from the config.

        Returns:
            The dtype name with surrounding whitespace and quote characters
            removed.

        Raises:
            ValueError: If the normalized name is not a supported dtype.
        """
        allowed = {"float16", "float32", "bfloat16"}
        # Tolerate values that arrive wrapped in literal quote characters.
        normalized = v.strip().strip('"').strip("'")
        if normalized not in allowed:
            raise ValueError(f"dtype must be one of {sorted(allowed)}, got {v!r}")
        return normalized

    @model_validator(mode="after")
    def _validate_post(self) -> "LearnedAtlasConfig":
        """Run cross-field checks once all individual fields are validated.

        Checks that every requested map is a known atlas statistic, that
        'mean' is among them, and that ``affine_native`` training has a
        non-blank ``affine_native_relpath``.

        Returns:
            This instance, unchanged.

        Raises:
            ValueError: If any of the three checks fails.
        """
        unknown = [m for m in self.emit_baseline_maps if m not in ATLAS_STATISTIC_NAMES]
        if unknown:
            raise ValueError(
                f"emit_baseline_maps contains unknown map(s) {unknown}; "
                f"valid options are {list(ATLAS_STATISTIC_NAMES)}."
            )
        if "mean" not in self.emit_baseline_maps:
            raise ValueError(
                "emit_baseline_maps must include 'mean' (atlas_mean.nii.gz is the "
                "learned template consumed by atlas_to_patient)."
            )
        # Strip before testing so a whitespace-only path (e.g. " ") is treated
        # as missing instead of passing validation and failing later.
        relpath = (self.affine_native_relpath or "").strip()
        if self.training_space == "affine_native" and not relpath:
            raise ValueError(
                "training_space='affine_native' requires affine_native_relpath "
                "to point at the affine-only-aligned native map."
            )
        return self