"""Config schema and constants for the learned_atlas cohort workflow.

The :class:`LearnedAtlasConfig` model is registered as the ``learned_atlas``
config namespace by the ``@workflow`` decorator (see ``workflow.py``). Any
``learned_atlas:`` block in a YAML config is validated against this schema at
config-load time and reachable from the workflow body via
``getattr(config, "learned_atlas")``.

The learned (deformable VoxelMorph/AtlasMorph) template remains drop-in
compatible with the averaging atlas: it emits the same
:data:`thesis.workflows.atlas._params.ATLAS_STATISTIC_NAMES` map set so the
existing ``atlas_to_patient`` and ``tract_similarity`` paths keep working.

This module depends only on pydantic + ``thesis.core.config.validators`` (and
the atlas params for the map-name source of truth); it never imports torch, so
importing the package is safe without the optional ``ml`` extra.
"""

from __future__ import annotations

from typing import List, Literal, Optional, Tuple

from pydantic import Field, field_validator, model_validator

from thesis.core.config.validators import BaseConfig
from thesis.workflows.atlas._params import ATLAS_STATISTIC_NAMES

# ---------------------------------------------------------------------------
# Tuning constants
# ---------------------------------------------------------------------------

#: Number of scaling-and-squaring integration steps (2**7 sub-steps); this is
#: the VoxelMorph default.
DEFAULT_INT_STEPS: int = 7

#: Offset added to sparse TDI volumes before the log transform used by the NCC
#: term, so zero-valued voxels map to a finite value.
DEFAULT_LOG_OFFSET: float = 0.001

#: Emitted-map contract, taken directly from the averaging atlas rather than
#: re-listed here, so the learned atlas cannot silently drift away from it.
LEARNED_ATLAS_BASELINE_MAPS: Tuple[str, ...] = ATLAS_STATISTIC_NAMES


# ---------------------------------------------------------------------------
# Nested config sub-models
# ---------------------------------------------------------------------------


class LearnedAtlasModelConfig(BaseConfig):
    """Architecture settings for the learned template and its registration net.

    Attributes:
        channels: Number of input channels; 1 for a single whole-map or
            per-bundle TDI volume. Must be >= 1.
        n_templates: Count ``K`` of learnable base templates. ``1`` trains one
            sharp template; values above 1 request the Phase-2 mixture-of-K
            stub. Must be >= 1.
        int_steps: Scaling-and-squaring steps used to integrate the velocity
            field into the deformation (allowed range 0-10). Integration is
            what is relied on for det(Jacobian) > 0; NOTE(review): the range
            admits 0, so confirm how the training body treats ``int_steps=0``.
        enc_features: Registration U-Net encoder channel widths, listed
            coarse-to-fine. Must be non-empty with every width >= 1.
        dec_features: Registration U-Net decoder channel widths. Must be
            non-empty with every width >= 1.
    """

    channels: int = Field(
        ge=1,
        default=1,
        description="Input channels (1 for single whole-map/per-bundle TDI volume).",
    )
    n_templates: int = Field(
        ge=1,
        default=1,
        description=(
            "Number of learnable base templates K. 1 = single sharp template; >1 = "
            "Phase-2 mixture-of-K stub (the training body trains K=1 and logs a "
            "NotImplemented warning)."
        ),
    )
    int_steps: int = Field(
        ge=0,
        le=10,
        default=DEFAULT_INT_STEPS,
        description=(
            "Scaling-and-squaring integration steps for the diffeomorphic warp. Higher = "
            "stronger invertibility (det(Jac) > 0), more memory."
        ),
    )
    enc_features: List[int] = Field(
        description="Registration U-Net encoder channel widths (coarse-to-fine).",
        default=[16, 32, 32, 32],
    )
    dec_features: List[int] = Field(
        description="Registration U-Net decoder channel widths.",
        default=[32, 32, 32, 16, 16],
    )

    @field_validator("enc_features", "dec_features")
    @classmethod
    def _check_feature_widths(cls, widths: List[int]) -> List[int]:
        """Reject an empty width list or any channel width below 1.

        Raises:
            ValueError: If ``widths`` is empty or contains a value < 1.
        """
        if not widths:
            raise ValueError("feature width lists must contain at least one entry")
        # The list is known to be non-empty here, so min() is safe.
        if min(widths) < 1:
            raise ValueError("all feature widths must be >= 1")
        return widths


class LearnedAtlasLossConfig(BaseConfig):
    """Relative weights of the loss terms used to train the learned atlas.

    Attributes:
        similarity_weight: Weight (>= 0) of the log-domain local-NCC
            reconstruction term.
        presence_weight: Weight (>= 0) of the presence-weighted soft-Dice term,
            evaluated on the subject's true support. It counters the
            do-nothing / regression-to-mean solution that a union-support
            correlation score would otherwise reward.
        smoothness_weight: Weight (>= 0) of the deformation-field gradient
            regulariser.
        log_offset: Strictly positive offset added to sparse TDI volumes before
            they are log-transformed.
    """

    similarity_weight: float = Field(
        ge=0.0,
        default=1.0,
        description="Weight on the log-NCC reconstruction term.",
    )
    presence_weight: float = Field(
        ge=0.0,
        default=0.5,
        description=(
            "Weight on the presence-weighted soft-Dice term on the subject's true "
            "support (anti regression-to-mean)."
        ),
    )
    smoothness_weight: float = Field(
        ge=0.0,
        default=0.5,
        description="Weight on the deformation-field gradient (smoothness) regulariser.",
    )
    log_offset: float = Field(
        gt=0.0,
        default=DEFAULT_LOG_OFFSET,
        description="Additive offset before log-transforming sparse TDI volumes.",
    )


class LearnedAtlasOptimizerConfig(BaseConfig):
    """Training-loop and optimizer settings for learned atlas training.

    Attributes:
        lr: Strictly positive Adam learning rate shared by the template and
            the registration U-Net.
        epochs: Number of passes over the cohort (>= 1).
        batch_size: Subjects per gradient step (>= 1); in practice limited by
            GPU memory for volumes of roughly 145^3 voxels.
        seed: Non-negative RNG seed used for reproducible training.
    """

    lr: float = Field(
        gt=0.0,
        default=1.0e-4,
        description="Adam learning rate.",
    )
    epochs: int = Field(
        ge=1,
        default=50,
        description="Training epochs over the cohort.",
    )
    batch_size: int = Field(
        ge=1,
        default=2,
        description="Subjects per gradient step (GPU VRAM bound for ~145^3 volumes).",
    )
    seed: int = Field(
        ge=0,
        default=0,
        description="RNG seed for reproducible training.",
    )


# ---------------------------------------------------------------------------
# Top-level config namespace model
# ---------------------------------------------------------------------------


class LearnedAtlasConfig(BaseConfig):
    """Pydantic schema for the ``learned_atlas`` top-level config key.

    Registered with the framework by ``@workflow(config_namespace="learned_atlas",
    config_schema=LearnedAtlasConfig)``; any ``learned_atlas:`` block in a YAML
    config is type-checked against this model and reachable via
    ``getattr(config, "learned_atlas")`` from the workflow body.

    Attributes:
        enabled: Master switch for learned atlas training.
        training_space: ``template_native`` (already template-registered warped
            volumes; the only space that is byte-compatible / drop-in for
            ``atlas_to_patient``) or ``affine_native`` (affine-only-aligned
            native maps, avoids the double-registration confound but is NOT
            consumable by the existing template-space ``atlas_to_patient`` jobs).
        device: Torch device string ('cuda', 'cuda:N', 'cpu'); surrounding
            whitespace is stripped.
        dtype: Training/autocast dtype name ('float32', 'float16', 'bfloat16').
        affine_native_relpath: Relative path under each patient directory to the
            affine-only-aligned native map; required (and must not be blank)
            when ``training_space='affine_native'``. May be a per-tract/bundle
            path and degrades gracefully to the whole-map case.
        emit_baseline_maps: Averaging-atlas maps to emit. Defaults to the FULL
            set so the learned dir is a true drop-in for any atlas_to_patient
            job config; must be a subset of ATLAS_STATISTIC_NAMES and include
            'mean'.
        output_subdir: Cohort output subdirectory for the learned template and
            per-subject fields. Single source of truth for the output location;
            the @produces CohortDir literal must equal this default.
        min_subjects: Minimum cohort size required to train.
        save_fields: Write per-subject deformation fields alongside the maps.
        verify_jacobian: Verify every predicted field has det(Jacobian) > 0.
        model: Network / template architecture knobs.
        loss: Loss-term weights.
        optimizer: Optimizer / training-loop knobs.
    """

    enabled: bool = Field(default=True, description="Enable learned atlas training.")
    training_space: Literal["template_native", "affine_native"] = Field(
        default="template_native",
        description=(
            "Training space. 'template_native' (default) trains on the "
            "template-registered warped volumes and is byte-compatible / "
            "drop-in for atlas_to_patient. 'affine_native' trains on "
            "affine-only-aligned native maps (avoids the double-registration "
            "confound) but the resulting template is NOT consumable by the "
            "existing template-space atlas_to_patient jobs without an extra warp."
        ),
    )
    device: str = Field(
        default="cuda",
        description="Torch device for training ('cuda', 'cuda:0', 'cpu').",
    )
    dtype: str = Field(
        default="float32",
        description="Training/autocast dtype ('float32', 'float16', 'bfloat16').",
    )
    affine_native_relpath: Optional[str] = Field(
        default=None,
        description=(
            "Relative path under each patient dir to the affine-only-aligned "
            "native map; required when training_space='affine_native'. May be a "
            "per-tract/bundle path; degrades to the whole-map case."
        ),
    )
    emit_baseline_maps: List[str] = Field(
        default=list(ATLAS_STATISTIC_NAMES),
        description=(
            "Averaging-atlas maps to emit from the learned template. Defaults to "
            "the full ATLAS_STATISTIC_NAMES set so the learned dir is a true "
            "drop-in for any atlas_to_patient job. Must be a subset of "
            f"{list(ATLAS_STATISTIC_NAMES)} and include 'mean'."
        ),
    )
    output_subdir: str = Field(
        default="learned_atlas",
        description=(
            "Cohort output subdirectory for the learned template "
            "(atlas_mean.nii.gz et al.) and per-subject fields. Coexists with "
            "'atlas'. MUST match the @produces CohortDir literal."
        ),
    )
    min_subjects: int = Field(default=5, ge=1, description="Minimum cohort size required to train.")
    save_fields: bool = Field(default=True, description="Write per-subject deformation fields.")
    verify_jacobian: bool = Field(
        default=True,
        description="Verify every predicted field has det(Jacobian) > 0 (no folding).",
    )
    model: LearnedAtlasModelConfig = Field(
        default_factory=LearnedAtlasModelConfig,
        description="Network / template architecture knobs.",
    )
    loss: LearnedAtlasLossConfig = Field(
        default_factory=LearnedAtlasLossConfig,
        description="Loss-term weights.",
    )
    optimizer: LearnedAtlasOptimizerConfig = Field(
        default_factory=LearnedAtlasOptimizerConfig,
        description="Optimizer / training-loop knobs.",
    )

    @field_validator("device")
    @classmethod
    def _check_device(cls, v: str) -> str:
        """Validate and normalise the device string without importing torch.

        Accepts ``'cpu'``, ``'cuda'`` or ``'cuda:N'`` after stripping
        surrounding whitespace, and returns the stripped value.

        Raises:
            ValueError: If the stripped string has any other shape.
        """
        import re

        normalized = v.strip()
        # The ordinal is restricted to ASCII digits on purpose. The Unicode-aware
        # digit class in a str pattern also matches non-ASCII decimal digits
        # (e.g. Arabic-Indic), which torch is not expected to accept as a
        # device ordinal.
        if not re.fullmatch(r"cpu|cuda(:[0-9]+)?", normalized):
            raise ValueError(f"device must be 'cpu', 'cuda', or 'cuda:N', got {v!r}")
        return normalized

    @field_validator("dtype")
    @classmethod
    def _check_dtype(cls, v: str) -> str:
        """Validate the torch dtype name without importing torch.

        Strips surrounding whitespace and any stray quote characters (e.g. from
        over-quoted YAML), then requires one of the supported names.

        Raises:
            ValueError: If the normalised name is not a supported dtype.
        """
        allowed = {"float16", "float32", "bfloat16"}
        normalized = v.strip().strip('"').strip("'")
        if normalized not in allowed:
            raise ValueError(f"dtype must be one of {sorted(allowed)}, got {v!r}")
        return normalized

    @model_validator(mode="after")
    def _validate_post(self) -> "LearnedAtlasConfig":
        """Cross-field checks: emit-map subset, 'mean' presence, affine-native path.

        Raises:
            ValueError: If ``emit_baseline_maps`` names a map outside
                ``ATLAS_STATISTIC_NAMES``, omits ``'mean'``, or if
                ``training_space='affine_native'`` is selected without a
                non-blank ``affine_native_relpath``.
        """
        unknown = [m for m in self.emit_baseline_maps if m not in ATLAS_STATISTIC_NAMES]
        if unknown:
            raise ValueError(
                f"emit_baseline_maps contains unknown map(s) {unknown}; "
                f"valid options are {list(ATLAS_STATISTIC_NAMES)}."
            )
        if "mean" not in self.emit_baseline_maps:
            raise ValueError(
                "emit_baseline_maps must include 'mean' (atlas_mean.nii.gz is the "
                "learned template consumed by atlas_to_patient)."
            )
        # A whitespace-only path is as unusable as a missing one, so it is
        # rejected too; the original truthiness test let "  " through.
        relpath = self.affine_native_relpath
        if self.training_space == "affine_native" and (relpath is None or not relpath.strip()):
            raise ValueError(
                "training_space='affine_native' requires affine_native_relpath "
                "to point at the affine-only-aligned native map."
            )
        return self
