"""
Pydantic models for configuration validation.

These models ensure that configurations have the correct structure
and valid values before being used in processing pipelines.
"""

import re
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union, cast

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from thesis.core.logging import get_logger
from thesis.core.utils import to_path

logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Enums (defined here to avoid circular imports with workflow modules)
# ---------------------------------------------------------------------------


class NormalizationMethod(str, Enum):
    """Method for normalizing streamline density volumes before atlas combination.

    Because the enum also subclasses ``str``, each member compares equal to its
    plain string value (e.g. ``NormalizationMethod.MAX == "max"``), so configs
    may specify the method as a lowercase string.

    Attributes:
        WAYTOTAL: Divide by waytotal (total streamline count). Default method
            for the FSL ProbTrackX2 backend; produces values representing the
            fraction of total streamlines passing through each voxel.
        MAX: Divide by the maximum voxel value in the volume.
            Scales each volume to [0, 1] range (assuming non-negative voxel
            values — TODO confirm against the implementation).
        SOFTMAX: Apply softmax normalization (exp(x) / sum(exp(x))).
            Produces a probability distribution that sums to 1.0 across the volume.
        STREAMLINE_DENSITY: MRtrix3 counterpart of ``WAYTOTAL``. Divides by
            the value stored in ``waytotal``, which the MRtrix3 workflow
            writes as the sum of SIFT2 per-streamline weights when SIFT2 is
            enabled (matching the SIFT2-weighted TDI numerator) or the raw
            streamline count when SIFT2 is off. Produces a streamline
            fraction in [0, 1].
    """

    WAYTOTAL = "waytotal"
    MAX = "max"
    SOFTMAX = "softmax"
    STREAMLINE_DENSITY = "streamline_density"


# ---------------------------------------------------------------------------
# Shared allowed-value sets for registration (referenced by both
# RegistrationConfig and the per-job RegistrationJobConfig so the two models
# never drift).
# ---------------------------------------------------------------------------

# Backends accepted for the registration ``method`` field.
REGISTRATION_METHODS = ["ants", "fireants", "fsl", "dipy"]
# Structural modalities accepted for ``moving_modality``.
REGISTRATION_MOVING_MODALITIES = ["t1", "t2"]
# Interpolation names accepted for ``interpolation``. NOTE(review): the
# spellings appear to follow ANTs interpolator names (e.g. ``LanczosWindowedSinc``,
# ``GenericLabel``) — confirm before relying on them for non-ANTs backends.
REGISTRATION_INTERPOLATIONS = [
    "Linear",
    "NearestNeighbor",
    "BSpline",
    "MultiLabel",
    "Gaussian",
    "CosineWindowedSinc",
    "WelchWindowedSinc",
    "HammingWindowedSinc",
    "LanczosWindowedSinc",
    "GenericLabel",
]
# Transform families accepted for ``transform_type``.
# Validation error messages interpolate these lists directly, so list order is
# the order in which users see the allowed options.
REGISTRATION_TRANSFORM_TYPES = ["Rigid", "Affine", "SyN"]


# Public names exported by ``from <this module> import *``. The REGISTRATION_*
# value lists defined above are not listed here and must be imported by name.
__all__ = [
    "NormalizationMethod",
    "BaseConfig",
    "PathConfig",
    "HardwareConfig",
    "AtlasConfig",
    "AtlasQCConfig",
    "S3Config",
    "NipypeConfig",
    "PreprocessingConfig",
    "FireantsRegistrationConfig",
    "RegistrationViewerConfig",
    "RegistrationJobConfig",
    "RegistrationConfig",
    "SegmentationConfig",
    "AtlasSourceConfig",
    "AtlasTransformConfig",
    "TransformJobConfig",
    "TransformsConfig",
    "SynthSegConfig",
    "TractographyConfig",
    "ValidationConfig",
    "QCConfig",
    "OutputSettingsConfig",
    "SideThresholdConfig",
    "HcpLooConfig",
    "TractSimilarityConfig",
    "ThresholdGridConfig",
    "TractSimilaritySweepConfig",
    "HCPConfig",
    "PipelineConfig",
]


class BaseConfig(BaseModel):
    """Common parent of every config model in this module.

    Unknown keys raise a validation error instead of being silently ignored,
    so a misspelled option in a config file fails loudly at load time.
    """

    # extra="forbid" is the whole point of this class; subclasses inherit it.
    model_config = ConfigDict(extra="forbid")


class PathConfig(BaseConfig):
    """Configuration for file paths and directories.

    ``inputs_dir``, ``assets_dir`` and ``output_dir`` are configured
    independently; none of them is derived from another. (The defaults do
    nest, however: the default ``inputs_dir``, ``data/raw``, lies inside the
    default ``assets_dir``, ``data``.)

    Every non-``None`` value is passed through ``to_path`` before field
    validation. An explicit ``None`` is passed through unchanged, so the
    optional fields can be cleared (e.g. ``null`` in YAML).

    Attributes:
        inputs_dir: Per-patient source data root. The patient's input
            directory is ``inputs_dir / <patient_id>``.
        assets_dir: Cohort-shared, read-only assets (templates, atlases,
            ROIs, reference images). Anchors ``DataFile``/``DataDir`` lookups.
        output_dir: Root for all workflow outputs/derivatives. The patient's
            output directory is ``output_dir / <patient_id>``.
        scratch_dir: Optional temporary/scratch directory override. When
            unset, scratch defaults to ``output_dir / <patient_id> / temp``.
        log_dir: Directory for runtime logs.
        scripts_dir: Optional directory of user-supplied workflow scripts
            (``*.py``).
    """

    inputs_dir: Path = Field(
        default=Path("data/raw"),
        description="Per-patient source data root (input_dir = inputs_dir/<patient_id>)",
    )
    assets_dir: Path = Field(
        default=Path("data"),
        description="Shared read-only assets (templates, atlases, ROIs)",
    )
    output_dir: Path = Field(
        default=Path("outputs"),
        description="Root for all workflow outputs (output_dir/<patient_id>)",
    )
    scratch_dir: Optional[Path] = Field(
        default=None, description="Optional temporary/scratch directory override"
    )
    log_dir: Path = Field(default=Path("logs"), description="Directory for log files")
    scripts_dir: Optional[Path] = Field(
        default=None,
        description=(
            "Optional directory of user-supplied workflow scripts (*.py). "
            "When set, `thesis list-workflows` scans the directory and lists "
            "any scripts whose @workflow decorator registers successfully."
        ),
    )

    @field_validator("*", mode="before")
    @classmethod
    def convert_to_path(cls, v: Any) -> Any:
        """Convert string paths to Path objects, expanding ~ and env vars.

        Args:
            v: Raw value supplied for any field of this model.

        Returns:
            ``None`` unchanged, otherwise the result of ``to_path(v)``.
        """
        # This validator runs for every field, including the Optional ones.
        # An explicit None must stay None (meaning "unset") rather than being
        # handed to to_path. For required fields, pydantic then rejects None.
        if v is None:
            return None
        return to_path(v)


class HardwareConfig(BaseConfig):
    """Compute resources made available to workflow execution.

    Attributes:
        threads: Number of CPU threads (between 1 and 128).
        memory_gb: Memory budget in gigabytes (between 1 and 1024).
        gpu_enabled: Toggle for GPU-aware execution.
        gpu_device: GPU index to use, or ``None`` to leave it unspecified.
        n_gpu_procs: Scheduler GPU worker slots (at least 1).
        n_gpus: Physical GPUs exposed to each worker slot (at least 1).
    """

    threads: int = Field(
        description="Number of CPU threads to use",
        default=4,
        ge=1,
        le=128,
    )
    memory_gb: int = Field(
        description="Maximum memory in GB",
        default=16,
        ge=1,
        le=1024,
    )
    gpu_enabled: bool = Field(
        description="Whether to use GPU acceleration",
        default=False,
    )
    gpu_device: Optional[int] = Field(
        description="GPU device ID (if multiple GPUs)",
        default=None,
    )
    n_gpu_procs: int = Field(
        description="Nipype GPU worker slots (concurrent GPU nodes allowed by scheduler)",
        default=1,
        ge=1,
    )
    n_gpus: int = Field(
        description="Physical GPUs visible per worker slot",
        default=1,
        ge=1,
    )


class AtlasConfig(BaseConfig):
    """Settings for building the cohort statistical atlas.

    Attributes:
        presence_value: Fraction in [0, 1] above which a normalised voxel
            counts as "present" for the prob_threshold computation.
        cov_mean_threshold_pct: Fraction in [0, 1] of the atlas mean-map
            maximum below which voxels are excluded from the CV computation.
        normalization_method: How each subject's streamline density volume
            is normalised before the volumes are combined.
        tractography_relpath: Per-patient relative path of the tractography
            outputs the atlas reads.
        output_subdir: Cohort-level subdirectory the atlas is written to.
    """

    presence_value: float = Field(
        description=(
            "Threshold for prob_threshold calculation. A voxel is counted as "
            "'present' when its normalised value exceeds this threshold."
        ),
        default=0.10,
        ge=0.0,
        le=1.0,
    )
    cov_mean_threshold_pct: float = Field(
        description=(
            "Global atlas mean-map percentage used to suppress low-signal voxels "
            "when computing cov. CV is only computed where mean exceeds this "
            "fraction of the atlas mean maximum."
        ),
        default=0.01,
        ge=0.0,
        le=1.0,
    )
    normalization_method: NormalizationMethod = Field(
        description=(
            "Method for normalizing streamline density volumes before atlas "
            "combination. Options: 'waytotal' (FSL ProbTrackX2 — divide by "
            "total streamline count), 'streamline_density' (MRtrix3 — divide "
            "by SIFT2 weight sum or raw count, matching the TDI units), "
            "'max' (divide by maximum voxel value), 'softmax' (probability "
            "distribution summing to 1.0)."
        ),
        default=NormalizationMethod.WAYTOTAL,
    )
    tractography_relpath: str = Field(
        description=(
            "Relative path under each patient directory where per-patient "
            "tractography output lives. The atlas workflow expects "
            "`<patient>/<tractography_relpath>/<run>/warped_streamlines/"
            "fdt_paths.nii.gz` (+ `waytotal`). Override to "
            "'tractography/mrtrix3' for the MRtrix3 backend."
        ),
        default="tractography/probtrackx2",
    )
    output_subdir: str = Field(
        description=(
            "Subdirectory under the cohort output directory where the "
            "statistical atlas (atlas_mean.nii.gz et al.) is written. "
            "Override per backend (e.g. 'atlas_mrtrix3') when running "
            "probtrackx2 and mrtrix3 into the same outputs/ tree to "
            "avoid clobbering."
        ),
        default="atlas",
    )


class S3Config(BaseConfig):
    """Configuration for S3 data download.

    Attributes:
        enabled: Whether S3 download is enabled.
        bucket: S3 bucket name.
        region: AWS region.
        prefix: Bucket prefix (e.g., 'HCP_1200').
        cache_policy: Download behavior when files exist locally; one of
            'skip_if_exists', 'check_size' or 'always'.
        max_retries: Maximum retry attempts for failed downloads.
        retry_backoff: Exponential backoff multiplier for retries (>= 1.0).
        required_patterns: File patterns that must be downloaded.
        optional_patterns: File patterns that should be downloaded if available.
    """

    enabled: bool = Field(
        default=False,
        description="Enable automatic S3 data download for HCP workflows",
    )
    bucket: str = Field(
        default="hcp-openaccess",
        description="S3 bucket name containing HCP data",
    )
    region: str = Field(
        default="us-east-1",
        description="AWS region where the S3 bucket is located",
    )
    prefix: str = Field(
        default="HCP_1200",
        description="S3 key prefix (folder path within bucket)",
    )
    cache_policy: str = Field(
        default="skip_if_exists",
        description=(
            "Cache behavior: 'skip_if_exists' (don't re-download), "
            "'check_size' (re-download if size differs), "
            "'always' (always re-download)"
        ),
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        description="Maximum number of retry attempts for failed downloads",
    )
    retry_backoff: float = Field(
        default=2.0,
        ge=1.0,
        description="Exponential backoff multiplier for retry delays",
    )
    required_patterns: List[str] = Field(
        default=[
            "T1w/Diffusion/data.nii*",
            "T1w/Diffusion/bvals",
            "T1w/Diffusion/bvecs",
            "T1w/Diffusion/nodif_brain_mask.nii*",
            "T1w/Diffusion.bedpostX/merged_*samples.nii.gz",
            "T1w/T1w_acpc_dc_restore_1.25.nii.gz",
        ],
        description="File patterns (glob) that must be downloaded for workflow to succeed",
    )
    optional_patterns: List[str] = Field(
        default=[
            "T1w/T1w_acpc_dc_restore_brain.nii.gz",
            "T1w/brainmask_fs.nii.gz",
        ],
        description="File patterns (glob) to download if available (won't fail if missing)",
    )

    @field_validator("cache_policy")
    @classmethod
    def validate_cache_policy(cls, v: str) -> str:
        """Validate cache policy is one of the allowed values.

        Raises:
            ValueError: If ``v`` is not a supported cache policy.
        """
        # A list rather than a set: the options are interpolated into the
        # error message, and str hashing is randomized per process, so a set
        # would print its members in a different order on different runs.
        allowed = ["skip_if_exists", "check_size", "always"]
        if v not in allowed:
            raise ValueError(f"cache_policy must be one of {allowed}, got '{v}'")
        return v


class PreprocessingConfig(BaseConfig):
    """Configuration for preprocessing pipeline.

    Attributes:
        denoise: Whether denoising is enabled.
        bias_correction: Whether bias correction is enabled.
        brain_extraction: Whether brain extraction is enabled.
        brain_extraction_method: Selected extraction backend; one of 'bet',
            'synthstrip' or 'ants'.
        motion_correction: Whether motion correction is enabled.
        eddy_correction: Whether eddy-current correction is enabled.
        topup: Whether TOPUP distortion correction is enabled.
        acq_params: Optional TOPUP acquisition parameters.
    """

    denoise: bool = Field(default=True, description="Whether to apply denoising")
    bias_correction: bool = Field(
        default=True, description="Whether to apply bias field correction"
    )
    brain_extraction: bool = Field(default=True, description="Whether to perform brain extraction")
    # The description lists every value accepted by validate_bet_method below;
    # keep the two in sync.
    brain_extraction_method: str = Field(
        default="synthstrip", description="Brain extraction method (bet, synthstrip, ants)"
    )
    motion_correction: bool = Field(default=True, description="Whether to apply motion correction")
    eddy_correction: bool = Field(
        default=True, description="Whether to apply eddy current correction"
    )
    topup: bool = Field(default=True, description="Whether to run TOPUP for distortion correction")
    acq_params: Optional[Dict[str, Any]] = Field(
        default=None, description="Acquisition parameters for TOPUP"
    )

    @field_validator("brain_extraction_method")
    @classmethod
    def validate_bet_method(cls, v: str) -> str:
        """Validate brain extraction method.

        Raises:
            ValueError: If ``v`` is not one of 'bet', 'synthstrip', 'ants'.
        """
        valid_methods = ["bet", "synthstrip", "ants"]
        if v not in valid_methods:
            raise ValueError(f"brain_extraction_method must be one of {valid_methods}")
        return v


class RegistrationViewerConfig(BaseConfig):
    """Settings controlling how the registration QC viewer is launched.

    Attributes:
        enabled: Master switch for viewer support.
        backend: Viewer implementation, either 'fsleyes' or 'html'.
        auto_open: Open the viewer automatically once registration finishes.
        overlay_opacity: Opacity in [0, 1] of the transformed patient overlay.
    """

    enabled: bool = Field(
        description="Whether registration QC viewer support is enabled",
        default=True,
    )
    backend: str = Field(description="Viewer backend name", default="fsleyes")
    auto_open: bool = Field(description="Launch viewer automatically after run", default=False)
    overlay_opacity: float = Field(
        description="Overlay opacity for the transformed patient image",
        default=0.5,
        ge=0.0,
        le=1.0,
    )

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        """Accept only the supported viewer backends; return ``v`` unchanged."""
        supported_backends = ["fsleyes", "html"]
        if v in supported_backends:
            return v
        raise ValueError(f"backend must be one of {supported_backends}")


class FireantsRegistrationConfig(BaseConfig):
    """Configuration for the FireANTs registration backend.

    Every per-scale iteration list (``rigid_iterations``, ``affine_iterations``,
    ``deformable_iterations``) must contain exactly one entry per element of
    ``scales``. Every element of ``scales`` must be >= 1.

    Attributes:
        device: Torch device string used by FireANTs.
        scales: Multi-resolution registration scales (one entry per level).
        affine_iterations: Iterations per scale for affine registration.
        deformable_iterations: Iterations per scale for deformable registration.
        optimizer: FireANTs optimizer name ('Adam' or 'SGD').
        affine_lr: Learning rate for affine registration.
        deformable_lr: Learning rate for deformable registration.
        cc_kernel_size: Cross-correlation kernel size.
        loss_type: Similarity metric for the staged registration (cc/mi/mse/
            fusedcc/fusedmi); auto-fused on CUDA.
        normalize: Min-max normalize intensities to [0, 1] before registration.
        deformation_type: FireANTs deformation model ('compositive' or 'geodesic').
        dtype: Torch dtype name ('float16', 'float32' or 'bfloat16').
        do_moments: Run the MomentsRegistration initialization stage.
        do_rigid: Run the RigidRegistration stage before affine/deformable.
        moments_scale: Downsampling scale used by the moments stage.
        moments_moments: Number of moments used by the moments stage.
        rigid_iterations: Iterations per scale for rigid registration.
        rigid_lr: Learning rate for rigid registration.
        deform_algo: Deformable engine for SyN-family transforms ('greedy' or 'syn').
        driver: How the staged registration is driven ('inprocess' or 'cli').
        deformable_max_spacing_mm: Optional cap (mm) on the spacing of the
            images fed to the deformable step; ``None`` means no cap.
        inverse_method: How the reverse warp is produced ('simpleitk' or 'fireants').
        inverse_max_iterations: Max fixed-point iterations for the SimpleITK inversion.
        inverse_tolerance: Mean-error tolerance (mm) for the SimpleITK inversion.
    """

    device: str = Field(default="cuda", description="Torch device used by FireANTs")
    scales: List[int] = Field(
        default=[4, 2, 1],
        description="Multi-resolution scales used by FireANTs",
    )
    affine_iterations: List[int] = Field(
        default=[200, 100, 50],
        description="Affine iterations per FireANTs scale",
    )
    deformable_iterations: List[int] = Field(
        default=[200, 100, 25],
        description="Deformable iterations per FireANTs scale",
    )
    optimizer: str = Field(default="Adam", description="FireANTs optimizer name")
    affine_lr: float = Field(default=3e-3, gt=0.0, description="Affine optimizer learning rate")
    deformable_lr: float = Field(
        default=0.5,
        gt=0.0,
        description="Deformable optimizer learning rate",
    )
    cc_kernel_size: int = Field(
        default=5,
        ge=1,
        description="Cross-correlation kernel size",
    )
    loss_type: str = Field(
        default="cc",
        description=(
            "Similarity metric for the rigid/affine/deformable stages: 'cc' "
            "(local normalized cross-correlation, same-modality), 'mi' (mutual "
            "information, robust to contrast differences / cross-dataset), 'mse', "
            "or the GPU-fused variants 'fusedcc'/'fusedmi'. On a CUDA device "
            "'cc'/'mi' are auto-upgraded to their fused equivalents; on CPU the "
            "fused variants fall back to non-fused."
        ),
    )
    normalize: bool = Field(
        default=True,
        description=(
            "Min-max normalize fixed and moving intensities to [0, 1] before "
            "registration (mirrors FireANTs' own template pipeline). Strongly "
            "recommended when the two images come from different sources / "
            "intensity ranges (e.g. MNI vs a study template)."
        ),
    )
    deformation_type: str = Field(
        default="compositive",
        description="FireANTs deformation type",
    )
    dtype: str = Field(default="float32", description="Torch dtype used by FireANTs")
    do_moments: bool = Field(
        default=True,
        description=(
            "Run the FireANTs MomentsRegistration initialization stage "
            "(center-of-mass + principal-axis alignment) before rigid/affine."
        ),
    )
    do_rigid: bool = Field(
        default=True,
        description="Run the FireANTs RigidRegistration stage before affine/deformable.",
    )
    moments_scale: int = Field(
        default=4,
        ge=1,
        description=(
            "Downsampling scale used by the FireANTs MomentsRegistration stage. "
            "FireANTs recommends 4 (the moments init is robust at low resolution)."
        ),
    )
    moments_moments: int = Field(
        default=1,
        ge=1,
        description=(
            "Number of moments used by the FireANTs MomentsRegistration stage. "
            "FireANTs recommends 1 (center-of-mass + first-order); 2 adds a "
            "second-order principal-axis term that is prone to 180-degree flips "
            "on near-symmetric brains."
        ),
    )
    rigid_iterations: List[int] = Field(
        default=[200, 100, 25],
        description="Rigid iterations per FireANTs scale.",
    )
    rigid_lr: float = Field(
        default=3e-2,
        gt=0.0,
        description="Rigid optimizer learning rate.",
    )
    deform_algo: str = Field(
        default="greedy",
        description=(
            "Deformable engine used for SyN-family transforms: 'greedy' (default, "
            "yields an analytic inverse warp) or 'syn' (no analytic-inverse export)."
        ),
    )
    driver: str = Field(
        default="inprocess",
        description=(
            "How the FireANTs staged registration is driven: 'inprocess' (Python "
            "API in the Nipype Function node) or 'cli' (shell out to the "
            "thesis-fireants-register console script)."
        ),
    )
    deformable_max_spacing_mm: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=(
            "Cap the spacing (mm) of the fixed/moving images fed to the deformable "
            "step.  When the template has finer spacing than this cap, both images "
            "are resampled to it before FireANTs runs.  Default 'null' = no cap "
            "(deformable runs at the template's native resolution; best quality, "
            "but the template's full grid must fit on the GPU).  Set to e.g. 1.0 "
            "to trade quality for VRAM headroom."
        ),
    )

    inverse_method: str = Field(
        default="simpleitk",
        description=(
            "How the reverse (template->patient) warp is produced for SyN-family "
            "transforms: 'simpleitk' (default) numerically inverts the forward "
            "displacement field with SimpleITK's InvertDisplacementFieldImageFilter "
            "— ANTs-free, has an explicit convergence tolerance, and works for both "
            "'greedy' and 'syn'; 'fireants' uses FireANTs' own inverse-consistency "
            "solve (greedy only; 'syn' has no analytic inverse). 'simpleitk' is "
            "recommended: it inverts the high-quality forward field directly instead "
            "of re-optimizing an under-constrained reverse warp."
        ),
    )
    inverse_max_iterations: int = Field(
        default=50,
        ge=1,
        description=(
            "Maximum fixed-point iterations for the SimpleITK displacement-field "
            "inversion (inverse_method='simpleitk')."
        ),
    )
    inverse_tolerance: float = Field(
        default=0.01,
        gt=0.0,
        description=(
            "Mean-error tolerance (in displacement units, i.e. mm) for the SimpleITK "
            "displacement-field inversion. Lower = tighter inverse, more iterations."
        ),
    )

    @field_validator("deform_algo")
    @classmethod
    def validate_deform_algo(cls, v: str) -> str:
        """Validate the FireANTs deformable engine."""
        valid_algos = ["greedy", "syn"]
        if v not in valid_algos:
            raise ValueError(f"deform_algo must be one of {valid_algos}")
        return v

    @field_validator("inverse_method")
    @classmethod
    def validate_inverse_method(cls, v: str) -> str:
        """Validate the reverse-warp reconstruction method."""
        valid_methods = ["simpleitk", "fireants"]
        if v not in valid_methods:
            raise ValueError(f"inverse_method must be one of {valid_methods}")
        return v

    @field_validator("driver")
    @classmethod
    def validate_driver(cls, v: str) -> str:
        """Validate the FireANTs driver."""
        valid_drivers = ["inprocess", "cli"]
        if v not in valid_drivers:
            raise ValueError(f"driver must be one of {valid_drivers}")
        return v

    @field_validator("optimizer")
    @classmethod
    def validate_optimizer(cls, v: str) -> str:
        """Validate the FireANTs optimizer."""
        valid_optimizers = ["Adam", "SGD"]
        if v not in valid_optimizers:
            raise ValueError(f"optimizer must be one of {valid_optimizers}")
        return v

    @field_validator("deformation_type")
    @classmethod
    def validate_deformation_type(cls, v: str) -> str:
        """Validate the FireANTs deformation type."""
        valid_deformation_types = ["compositive", "geodesic"]
        if v not in valid_deformation_types:
            raise ValueError(f"deformation_type must be one of {valid_deformation_types}")
        return v

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, v: str) -> str:
        """Validate the FireANTs torch dtype."""
        valid_dtypes = ["float16", "float32", "bfloat16"]
        if v not in valid_dtypes:
            raise ValueError(f"dtype must be one of {valid_dtypes}")
        return v

    @field_validator("loss_type")
    @classmethod
    def validate_loss_type(cls, v: str) -> str:
        """Validate the FireANTs similarity metric."""
        valid_loss_types = ["cc", "mi", "mse", "fusedcc", "fusedmi"]
        if v not in valid_loss_types:
            raise ValueError(f"loss_type must be one of {valid_loss_types}")
        return v

    @model_validator(mode="after")
    def validate_scale_lengths(self) -> "FireantsRegistrationConfig":
        """Ensure scales are usable and per-scale iteration lists match them.

        Raises:
            ValueError: If ``scales`` is empty or contains a value below 1, or
                if any per-scale iteration list differs in length from ``scales``.
        """
        # With an empty scales list (and matching empty iteration lists), the
        # length checks below would pass vacuously and leave no levels to run.
        if not self.scales:
            raise ValueError("scales must contain at least one scale")
        # Scales are integer downsampling factors; 0 or negatives are meaningless.
        if any(scale < 1 for scale in self.scales):
            raise ValueError(f"scales must all be >= 1, got {self.scales}")
        n_scales = len(self.scales)
        if len(self.affine_iterations) != n_scales:
            raise ValueError("affine_iterations must have the same length as scales")
        if len(self.deformable_iterations) != n_scales:
            raise ValueError("deformable_iterations must have the same length as scales")
        if len(self.rigid_iterations) != n_scales:
            raise ValueError("rigid_iterations must have the same length as scales")
        return self

    @model_validator(mode="after")
    def _warn_fireants_inverse_with_syn(self) -> "FireantsRegistrationConfig":
        """Warn when FireANTs' own inverse is requested with the 'syn' engine.

        Per the ``inverse_method`` description, the 'fireants' inverse is
        available for ``deform_algo='greedy'`` only. This warns instead of
        raising because the inverse matters only for SyN-family transforms,
        and the transform family is configured outside this model.
        """
        if self.inverse_method == "fireants" and self.deform_algo == "syn":
            logger.warning(
                "fireants.inverse_method='fireants' is only supported with "
                "deform_algo='greedy' (got deform_algo=%r); SyN-family inverse "
                "warps will not be available. Use inverse_method='simpleitk'.",
                self.deform_algo,
            )
        return self


class RegistrationJobConfig(BaseConfig):
    """Configuration for a single named registration job.

    A job overrides the shared ``RegistrationConfig`` defaults for one
    patient-to-template registration. All override fields default to ``None``
    (meaning "inherit the shared value"); only ``name`` is required.

    The ``fireants`` field, when set, is a sparse dict merged onto
    ``registration.fireants``. Its keys are checked here against the fields of
    ``FireantsRegistrationConfig``, so typos fail at load time. Its values are
    NOT validated here, because they only make sense after merging with the
    shared settings (e.g. the per-scale list-length checks).

    NOTE(review): pydantic v2's ``model_copy(update=...)`` does not run
    validation on the updated values. If the merge site uses it, the merged
    values should be re-validated there, e.g. with
    ``FireantsRegistrationConfig.model_validate(merged_dict)``.

    Attributes:
        name: Unique job name used in node/output naming and referenced by
            ``transforms.jobs[*].from_registration``.
        method: Optional registration-backend override.
        moving_modality: Optional moving-image modality override.
        moving_image: Optional explicit moving-image path override.
        fixed_image: Optional fixed-image path override.
        interpolation: Optional interpolation-mode override.
        metric: Optional similarity-metric override.
        transform_type: Optional transform-family override.
        use_float: Optional float-precision override.
        fireants: Optional sparse FireANTs override block.
    """

    name: str = Field(description="Unique registration job name used in node/output naming")
    method: Optional[str] = Field(default=None, description="Registration backend override")
    moving_modality: Optional[str] = Field(
        default=None, description="Moving-image modality override (t1, t2)"
    )
    moving_image: Optional[str] = Field(
        default=None, description="Explicit moving-image path override"
    )
    fixed_image: Optional[str] = Field(default=None, description="Fixed template image override")
    interpolation: Optional[str] = Field(default=None, description="Interpolation method override")
    metric: Optional[str] = Field(default=None, description="Similarity metric override")
    transform_type: Optional[str] = Field(
        default=None, description="Transform family override (Rigid, Affine, SyN)"
    )
    use_float: Optional[bool] = Field(
        default=None, description="Float-precision override for registration"
    )
    fireants: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Sparse FireANTs override block merged onto registration.fireants "
            "(re-validated). Only the keys present here override the shared values."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job registration method (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_METHODS:
            raise ValueError(f"method must be one of {REGISTRATION_METHODS}")
        return v

    @field_validator("moving_modality")
    @classmethod
    def validate_moving_modality(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job moving-image modality (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_MOVING_MODALITIES:
            raise ValueError(f"moving_modality must be one of {REGISTRATION_MOVING_MODALITIES}")
        return v

    @field_validator("interpolation")
    @classmethod
    def validate_interpolation(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job interpolation method (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_INTERPOLATIONS:
            raise ValueError(f"interpolation must be one of {REGISTRATION_INTERPOLATIONS}")
        return v

    @field_validator("transform_type")
    @classmethod
    def validate_transform_type(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job transform family (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_TRANSFORM_TYPES:
            raise ValueError(f"transform_type must be one of {REGISTRATION_TRANSFORM_TYPES}")
        return v

    @field_validator("fireants")
    @classmethod
    def validate_fireants_keys(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Reject override keys that are not ``FireantsRegistrationConfig`` fields.

        A plain dict bypasses ``extra="forbid"``, so without this check a
        misspelled override key would be silently carried along.

        Raises:
            ValueError: If ``v`` contains keys that are not FireANTs config fields.
        """
        if v is None:
            return v
        known_keys = set(FireantsRegistrationConfig.model_fields)
        unknown = sorted(set(v) - known_keys)
        if unknown:
            raise ValueError(
                f"fireants override has unknown keys {unknown}; "
                f"valid keys are {sorted(known_keys)}"
            )
        return v


class RegistrationConfig(BaseConfig):
    """Configuration for image registration.

    When ``jobs`` is empty, the top-level fields describe a single implicit job
    named 'patient_to_template'. Otherwise they are the shared defaults that
    each ``RegistrationJobConfig`` may override.

    ``method``, ``moving_modality``, ``interpolation`` and ``transform_type``
    are checked against the same module-level allowed-value lists used by
    ``RegistrationJobConfig``. A value accepted as a shared default is
    therefore also accepted as a per-job override.

    Attributes:
        enabled: Whether the registration workflow is enabled.
        method: Registration backend.
        moving_modality: Structural modality used as the moving image.
        moving_image: Optional explicit moving-image override.
        fixed_image: Template-space fixed image.
        interpolation: Interpolation mode.
        metric: Similarity metric.
        transform_type: Transform family.
        use_float: Whether float precision is used.
        collapse_output_transforms: Whether ANTs should collapse output transforms.
        write_composite_transform: Whether ANTs should emit composite transforms.
        output_subdir: Output subdirectory under the patient output dir.
        fireants: FireANTs backend settings.
        viewer: Registration QC viewer settings.
        jobs: Named registration jobs; names must be unique.
    """

    enabled: bool = Field(default=True, description="Whether registration is enabled")
    method: str = Field(
        default="ants", description="Registration method (ants, fireants, fsl, dipy)"
    )
    moving_modality: str = Field(
        default="t1", description="Structural modality for the moving image (t1, t2)"
    )
    moving_image: Optional[str] = Field(
        default=None,
        description="Optional explicit moving-image path override",
    )
    fixed_image: str = Field(default="", description="Fixed template image path")
    interpolation: str = Field(default="Linear", description="Interpolation method")
    metric: str = Field(default="MI", description="Similarity metric (MI, CC, etc.)")
    transform_type: str = Field(default="SyN", description="Type of transform (Rigid, Affine, SyN)")
    use_float: bool = Field(default=True, description="Use float precision for registration")
    collapse_output_transforms: bool = Field(
        default=True,
        description="Collapse ANTs output transforms where possible",
    )
    write_composite_transform: bool = Field(
        default=True,
        description="Write ANTs composite transform outputs",
    )
    output_subdir: str = Field(
        default="registration",
        description="Patient output subdirectory for registration artifacts",
    )
    fireants: FireantsRegistrationConfig = Field(
        default_factory=FireantsRegistrationConfig,
        description="FireANTs backend settings",
    )
    viewer: RegistrationViewerConfig = Field(
        default_factory=RegistrationViewerConfig,
        description="Registration QC viewer settings",
    )
    jobs: List[RegistrationJobConfig] = Field(
        default_factory=list,
        description=(
            "Named registration jobs. When empty, the top-level fields above are "
            "treated as a single implicit job named 'patient_to_template'. Each "
            "job may override the shared defaults and supply a sparse 'fireants' "
            "block merged onto registration.fireants."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: str) -> str:
        """Validate registration method."""
        if v not in REGISTRATION_METHODS:
            raise ValueError(f"method must be one of {REGISTRATION_METHODS}")
        return v

    @field_validator("moving_modality")
    @classmethod
    def validate_moving_modality(cls, v: str) -> str:
        """Validate the selected moving-image modality."""
        if v not in REGISTRATION_MOVING_MODALITIES:
            raise ValueError(f"moving_modality must be one of {REGISTRATION_MOVING_MODALITIES}")
        return v

    @field_validator("interpolation")
    @classmethod
    def validate_interpolation(cls, v: str) -> str:
        """Validate the shared interpolation method.

        RegistrationJobConfig already validates its override against the same
        list. Checking the shared default too keeps the two models in agreement.
        """
        if v not in REGISTRATION_INTERPOLATIONS:
            raise ValueError(f"interpolation must be one of {REGISTRATION_INTERPOLATIONS}")
        return v

    @field_validator("transform_type")
    @classmethod
    def validate_transform_type(cls, v: str) -> str:
        """Validate the selected transform family."""
        if v not in REGISTRATION_TRANSFORM_TYPES:
            raise ValueError(f"transform_type must be one of {REGISTRATION_TRANSFORM_TYPES}")
        return v

    @model_validator(mode="after")
    def _validate_unique_job_names(self) -> "RegistrationConfig":
        """Ensure all registration job names are unique.

        Raises:
            ValueError: Listing (sorted) every job name used more than once.
        """
        # Single pass instead of list.count per name; sorted output keeps the
        # error message stable across runs (set order depends on str hashing).
        seen: set = set()
        duplicates: set = set()
        for job in self.jobs:
            if job.name in seen:
                duplicates.add(job.name)
            seen.add(job.name)
        if duplicates:
            raise ValueError(
                f"registration.jobs names must be unique; duplicates: {sorted(duplicates)}"
            )
        return self


class SegmentationConfig(BaseConfig):
    """Settings for the tissue-segmentation step.

    Attributes:
        method: Segmentation backend; one of ``synthseg``, ``fast`` or
            ``freesurfer``.
        tissue_types: Tissue classes to segment (GM, WM and CSF by default).
        create_masks: Whether binary masks are created.
        labels: Optional mapping from label name to integer label value.
    """

    method: str = Field("synthseg", description="Segmentation method")
    tissue_types: List[str] = Field(
        ["GM", "WM", "CSF"],
        description="Tissue types to segment",
    )
    create_masks: bool = Field(True, description="Whether to create binary masks")
    labels: Optional[Dict[str, int]] = Field(
        None,
        description="Label mapping (name -> value)",
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
        """Reject segmentation backends that are not supported.

        Raises:
            ValueError: If ``v`` is not ``synthseg``, ``fast`` or ``freesurfer``.
        """
        allowed = ["synthseg", "fast", "freesurfer"]
        if v in allowed:
            return v
        raise ValueError(f"method must be one of {allowed}")


class AtlasSourceConfig(BaseConfig):
    """One template-space atlas used as a source of ROI label maps.

    Attributes:
        name: Unique source name, used when naming nodes and outputs.
        roi_file: Path to the atlas image in template space.
        transform: Optional key into ``transforms.atlas_transforms`` selecting
            the transform applied to this atlas.
        label_file: Optional label CSV; required when ``label_name`` entries
            are used.
        waypoint_labels: ROI label configuration for this source (empty by
            default).
    """

    name: str = Field(description="Unique atlas source name used in node/output naming")
    roi_file: str = Field(description="Template-space atlas image path")
    transform: Optional[str] = Field(
        None,
        description=(
            "Optional named transform from transforms.atlas_transforms to use for this atlas"
        ),
    )
    label_file: Optional[str] = Field(
        None,
        description="Optional atlas label CSV path; required when label_name entries are used",
    )
    waypoint_labels: Dict[str, Any] = Field(
        description="ROI label configuration for this atlas source",
        default_factory=dict,
    )


class TractographyConfig(BaseConfig):
    """Configuration for tractography.

    Attributes:
        method: Tractography backend (``probtrackx2``, ``tckgen``,
            ``dsi_studio`` or ``mrtrix3``).
        run_tractography: Optional workflow-level execution toggle.
        n_samples: Streamlines per seed voxel.
        n_steps: Maximum steps per streamline.
        step_length: Step length in millimetres.
        curvature_threshold: Maximum curvature threshold.
        loop_check: Whether ProbTrackX2 performs loop checks on paths.
        dist_thresh: Samples shorter than this (mm) are discarded; 0 disables.
        fibst: Starting fibre for tracking (only effective when ``rand_fib`` is 0).
        rand_fib: Fibre sampling mode (0-3).
        mod_euler: Whether modified Euler streamlining is used.
        hemisphere: Which hemisphere(s) to run tractography for.
        mem_gb_gpu: Nipype scheduler memory hint (GB) for GPU ProbTrackX2 nodes.
        mem_gb_cpu: Nipype scheduler memory hint (GB) for CPU ProbTrackX2 nodes.
        gpu_runtime_env: Extra environment variables for the GPU ProbTrackX2
            subprocess only.
        use_waypoints: Whether waypoint masks are used.
        waypoint_files: Optional named waypoint mask paths.
        roi_files: Optional template-space ROI files to transform.
        seed_roi: Optional seed ROI path.
        waypoint_masks: Optional waypoint mask paths.
        atlas_sources: Optional atlas ROI source list.
        roi_labels: Legacy single-atlas ROI configuration.
        synthseg_roi_labels: Optional SynthSeg-derived ROI configuration.
        force_dir: Whether ProbTrackX2 uses ``--forcedir``.
        opd: Whether ProbTrackX2 uses ``--opd`` (output the path distribution).
        tckgen_algorithm, tckgen_select, tckgen_seeds, tckgen_minlength,
        tckgen_maxlength, tckgen_backtrack, tckgen_crop_at_gmwmi, tckgen_cutoff:
            MRtrix3 ``tckgen`` options. ``tckgen_seeds`` may be 0 (unlimited),
            and ``tckgen_minlength`` must not exceed ``tckgen_maxlength``.
        response_algorithm: ``dwi2response`` algorithm.
        fod_algorithm: ``dwi2fod`` algorithm.
        use_mtnormalise: Whether ``mtnormalise`` runs after ``dwi2fod``.
        fivett_algorithm: ``5ttgen`` backend.
        mask_source: Brain-mask source (``fsl_nodif`` or ``dwi2mask``).
        mask_dilate_voxels: Dilation passes applied to the brain mask.
        use_sift2: Whether ``tcksift2`` weights streamlines for ``tckmap``.
        seed_strategy: ``tckgen`` seed source (``roi`` or ``gmwmi``).
        gmwmi_apply_roi_filters: Whether ROI filters still apply under GMWMI seeding.
    """

    method: str = Field(default="probtrackx2", description="Tractography method")
    run_tractography: Optional[bool] = Field(
        default=None, description="Whether to run tractography in HCP workflow"
    )
    n_samples: int = Field(default=5000, ge=100, description="Number of streamlines per voxel")
    n_steps: int = Field(default=2000, ge=100, description="Maximum number of steps")
    step_length: float = Field(default=0.5, ge=0.1, le=2.0, description="Step length in mm")
    curvature_threshold: float = Field(
        default=0.2, ge=0.0, le=1.0, description="Curvature threshold (0-1)"
    )
    loop_check: bool = Field(
        default=True,
        description="Perform loop checks on paths (slower, but allows lower curvature threshold)",
    )
    dist_thresh: float = Field(
        default=0.0,
        ge=0.0,
        description="Discard samples shorter than this threshold in mm (0 = no limit)",
    )
    fibst: int = Field(
        default=1,
        ge=1,
        description="Starting fibre for tracking (only effective when rand_fib=0)",
    )
    rand_fib: int = Field(
        default=0,
        ge=0,
        le=3,
        description=(
            "Fibre sampling mode: 0=default, 1=random (f>thresh), "
            "2=proportional (f>thresh), 3=all populations"
        ),
    )
    mod_euler: bool = Field(
        default=False,
        description="Use modified Euler streamlining",
    )
    hemisphere: str = Field(
        default="both",
        description=(
            "Which hemisphere(s) to run tractography for: "
            "left, right, both (merged), or both-separately (two independent runs)"
        ),
    )
    mem_gb_gpu: float = Field(
        default=8.0,
        gt=0,
        description="Nipype scheduler memory hint in GB for GPU ProbTrackX2 nodes",
    )
    mem_gb_cpu: float = Field(
        default=4.0,
        gt=0,
        description="Nipype scheduler memory hint in GB for CPU ProbTrackX2 nodes",
    )
    gpu_runtime_env: Optional[Dict[str, str]] = Field(
        default=None,
        description=(
            "Environment variables injected ONLY into the GPU probtrackx2 "
            "subprocess (the Nipype CommandLine node's ``.inputs.environ``), and "
            "thus only into its ``singularity exec --cleanenv`` FSL container. "
            "Keys MUST carry a SINGULARITYENV_/APPTAINERENV_ prefix to survive "
            "``--cleanenv`` and reach inside the container. Never applied to the "
            "ANTs or SynthSeg nodes, and ignored for CPU runs. Example: "
            "{'SINGULARITYENV_LD_PRELOAD': '/path/HAMi-core/build/libvgpu.so', "
            "'SINGULARITYENV_CUDA_DEVICE_MEMORY_LIMIT': '11g'} to cap each "
            "probtrackx2 to 11 GB so several co-locate on one GPU without OOM."
        ),
    )
    use_waypoints: bool = Field(default=True, description="Whether to use waypoint masks")
    waypoint_files: Optional[Dict[str, Path]] = Field(
        default=None, description="Paths to waypoint mask files"
    )
    roi_files: Optional[List[Path]] = Field(
        default=None, description="Template-space ROI files to transform"
    )
    seed_roi: Optional[Path] = Field(
        default=None, description="Seed ROI (patient or template space depending on workflow)"
    )
    waypoint_masks: Optional[List[Path]] = Field(default=None, description="Waypoint mask files")
    atlas_sources: Optional[List[AtlasSourceConfig]] = Field(
        default=None,
        description="Template-space atlas sources, each with its own atlas image and labels",
    )
    roi_labels: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Label-map ROI configuration (label_file, roi_file, waypoint_labels)",
    )
    synthseg_roi_labels: Optional[Dict[str, Any]] = Field(
        default=None,
        description="SynthSeg-sourced ROI label configuration (waypoint_labels with label_values)",
    )
    force_dir: bool = Field(default=True, description="ProbTrackX2 --forcedir flag")
    opd: bool = Field(default=True, description="ProbTrackX2 --opd flag")

    # MRtrix3-specific tckgen parameters (ignored when method != "mrtrix3")
    tckgen_algorithm: str = Field(
        default="iFOD2",
        description="tckgen tracking algorithm (iFOD2, iFOD1, SD_STREAM, Tensor_Det, ...)",
    )
    tckgen_select: int = Field(
        default=5000,
        ge=1,
        description="tckgen target streamline count (-select)",
    )
    # ge=0 (not ge=1): the description and MRtrix3's -seeds semantics both
    # treat 0 as "no seed cap", so 0 must be accepted.
    tckgen_seeds: int = Field(
        default=5_000_000,
        ge=0,
        description="tckgen seed attempt cap (-seeds); 0 means unlimited",
    )
    tckgen_minlength: float = Field(
        default=30.0,
        ge=0.0,
        description="tckgen minimum streamline length in mm (-minlength)",
    )
    tckgen_maxlength: float = Field(
        default=250.0,
        gt=0.0,
        description="tckgen maximum streamline length in mm (-maxlength)",
    )
    tckgen_backtrack: bool = Field(
        default=True,
        description="tckgen -backtrack flag (improves ACT-based tracking)",
    )
    tckgen_crop_at_gmwmi: bool = Field(
        default=True,
        description="tckgen -crop_at_gmwmi flag (precise endpoint placement)",
    )
    tckgen_cutoff: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="tckgen FOD amplitude cutoff (-cutoff). None uses MRtrix default.",
    )
    response_algorithm: str = Field(
        default="dhollander",
        description="dwi2response algorithm (dhollander | tournier | msmt_5tt | fa | manual)",
    )
    fod_algorithm: str = Field(
        default="msmt_csd",
        description="dwi2fod algorithm (msmt_csd for multi-shell, csd for single-shell)",
    )
    use_mtnormalise: bool = Field(
        default=True,
        description="Run mtnormalise after dwi2fod to correct intensity bias across tissues",
    )
    fivett_algorithm: str = Field(
        default="fsl",
        description="5ttgen backend (fsl | freesurfer | hsvs)",
    )
    mask_source: str = Field(
        default="fsl_nodif",
        description=(
            "Brain-mask source: 'fsl_nodif' reuses the HCP nodif_brain_mask, "
            "'dwi2mask' generates one from the DWI."
        ),
    )
    mask_dilate_voxels: int = Field(
        default=1,
        ge=0,
        description="Number of dilation passes applied to the brain mask (MRtrix ACT recommends 1)",
    )
    use_sift2: bool = Field(
        default=True,
        description="Run tcksift2 after tckgen to weight streamlines for tckmap",
    )
    seed_strategy: str = Field(
        default="roi",
        description=(
            "tckgen seed source (MRtrix3 only): 'roi' seeds from the atlas/"
            "synthseg seed mask (default); 'gmwmi' seeds from the GM-WM interface "
            "via -seed_gmwmi (requires ACT, which is always enabled here)."
        ),
    )
    gmwmi_apply_roi_filters: bool = Field(
        default=True,
        description=(
            "When seed_strategy='gmwmi', still apply any configured waypoint/avoid/"
            "stop/target masks as tckgen -include/-exclude/-mask filters. False "
            "forces pure whole-brain tracking. Ignored when seed_strategy='roi'."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: str) -> str:
        """Validate tractography method."""
        valid_methods = ["probtrackx2", "tckgen", "dsi_studio", "mrtrix3"]
        if v not in valid_methods:
            raise ValueError(f"method must be one of {valid_methods}")
        return v

    @field_validator("gpu_runtime_env")
    @classmethod
    def warn_unprefixed_gpu_runtime_env(
        cls, v: Optional[Dict[str, str]]
    ) -> Optional[Dict[str, str]]:
        """Warn on keys that will not survive ``singularity exec --cleanenv``.

        Only ``SINGULARITYENV_``/``APPTAINERENV_``-prefixed variables are
        forwarded into the container past ``--cleanenv``. A bare key (e.g.
        ``LD_PRELOAD``) would be set in the wrapper's host shell and then
        silently stripped, so it would never reach probtrackx2 — a confusing
        no-op we surface here rather than letting it fail at runtime. The
        value is returned unchanged; this validator only warns.
        """
        if v:
            bad = [k for k in v if not k.startswith(("SINGULARITYENV_", "APPTAINERENV_"))]
            if bad:
                # NOTE(review): the "{}" placeholder assumes get_logger returns a
                # brace-style logger (e.g. loguru); with stdlib %-style logging
                # it would not be interpolated — confirm against thesis.core.logging.
                logger.warning(
                    "tractography.gpu_runtime_env keys {} lack a "
                    "SINGULARITYENV_/APPTAINERENV_ prefix; they will NOT survive "
                    "`singularity exec --cleanenv` and will not reach the "
                    "probtrackx2 container.",
                    bad,
                )
        return v

    @field_validator("response_algorithm")
    @classmethod
    def validate_response_algorithm(cls, v: str) -> str:
        """Validate dwi2response algorithm."""
        valid = ("dhollander", "tournier", "msmt_5tt", "fa", "manual")
        if v not in valid:
            raise ValueError(f"response_algorithm must be one of {valid}")
        return v

    @field_validator("fod_algorithm")
    @classmethod
    def validate_fod_algorithm(cls, v: str) -> str:
        """Validate dwi2fod algorithm."""
        valid = ("msmt_csd", "csd")
        if v not in valid:
            raise ValueError(f"fod_algorithm must be one of {valid}")
        return v

    @field_validator("fivett_algorithm")
    @classmethod
    def validate_fivett_algorithm(cls, v: str) -> str:
        """Validate 5ttgen backend."""
        valid = ("fsl", "freesurfer", "hsvs")
        if v not in valid:
            raise ValueError(f"fivett_algorithm must be one of {valid}")
        return v

    @field_validator("mask_source")
    @classmethod
    def validate_mask_source(cls, v: str) -> str:
        """Validate brain-mask source."""
        valid = ("fsl_nodif", "dwi2mask")
        if v not in valid:
            raise ValueError(f"mask_source must be one of {valid}")
        return v

    @field_validator("seed_strategy")
    @classmethod
    def validate_seed_strategy(cls, v: str) -> str:
        """Validate tckgen seed source."""
        valid = ("roi", "gmwmi")
        if v not in valid:
            raise ValueError(f"seed_strategy must be one of {valid}")
        return v

    @field_validator("hemisphere")
    @classmethod
    def validate_hemisphere(cls, v: str) -> str:
        """Validate hemisphere value."""
        valid = ("left", "right", "both", "both-separately")
        if v not in valid:
            raise ValueError(f"hemisphere must be one of {valid}")
        return v

    @model_validator(mode="after")
    def _validate_tckgen_length_range(self) -> "TractographyConfig":
        """Ensure the tckgen length window is non-empty.

        Raises:
            ValueError: If ``tckgen_minlength`` exceeds ``tckgen_maxlength``,
                which would leave no admissible streamline length.
        """
        if self.tckgen_minlength > self.tckgen_maxlength:
            raise ValueError(
                "tckgen_minlength must not exceed tckgen_maxlength "
                f"(got {self.tckgen_minlength} > {self.tckgen_maxlength})"
            )
        return self


class HCPConfig(BaseConfig):
    """Configuration for HCP preprocessed data.

    All relative paths are resolved against the subject input directory.

    Attributes:
        diffusion_dir: Diffusion subdirectory under the subject input directory.
        bedpostx_dir: BedpostX output directory.
        t1_image: T1-weighted image path.
        t2_image: Optional T2-weighted image path.
        n_fibers: Number of BedpostX fibres per voxel; must be at least 1.
        mask_name: Default diffusion brain-mask filename.
        mask_path: Optional full mask-path override.
    """

    diffusion_dir: str = Field(
        default="T1w/Diffusion", description="Diffusion data subdirectory relative to input_dir"
    )
    bedpostx_dir: str = Field(
        default="T1w/Diffusion.bedpostX",
        description="BedpostX output directory relative to input_dir",
    )
    t1_image: str = Field(
        default="T1w/T1w_acpc_dc_restore_1.25.nii",
        description="T1 image path relative to input_dir",
    )
    t2_image: Optional[str] = Field(
        default=None,
        description="Optional T2 image path relative to input_dir",
    )
    # A BedpostX run models at least one fibre population per voxel; zero or
    # negative counts are meaningless and are rejected here.
    n_fibers: int = Field(default=3, ge=1, description="Number of fibers in BedpostX")
    mask_name: str = Field(default="nodif_brain_mask.nii", description="Brain mask filename")
    mask_path: Optional[str] = Field(
        default=None, description="Brain mask path relative to input_dir (overrides default)"
    )


class TransformJobConfig(BaseConfig):
    """One image-transformation job run by the transform workflows.

    Each job resamples a list of input images using either a static transform
    direction from ``TransformsConfig`` or the chains produced by a named
    registration job. Several jobs (e.g. mean, std and probability maps) can
    run in a single workflow invocation.

    Attributes:
        name: Unique job name used in output filenames and Nipype node names.
        input_files: Explicit image paths; ``{patient_id}`` is substituted at
            runtime.
        direction: ``template_to_patient`` uses ``transforms.template_to_patient``
            and ``transforms.reference_image``; ``patient_to_template`` uses
            ``transforms.patient_to_template`` and
            ``transforms.template_reference_image``.
        interpolation: ANTs interpolation method applied to every input image.
        output_subdir: Subdirectory of the patient output directory that
            receives the transformed images.
        reference_image: Optional per-job override of the reference image for
            the chosen direction (supports ``{patient_id}``).
        from_registration: Optional registration job name whose produced
            transforms replace the static ``transforms.*`` chains.
    """

    name: str = Field(description="Unique job name used in output filenames and node names")
    input_files: List[str] = Field(
        description="Explicit paths to images to transform; {patient_id} substituted at runtime",
        default_factory=list,
    )
    direction: str = Field(
        "template_to_patient",
        description="Transform direction: template_to_patient or patient_to_template",
    )
    interpolation: str = Field(
        "Linear",
        description="ANTs interpolation method (Linear, NearestNeighbor, BSpline, etc.)",
    )
    output_subdir: str = Field(
        "transformed",
        description="Output subdirectory under patient output dir for transformed images",
    )
    reference_image: Optional[str] = Field(
        None,
        description=(
            "Per-job reference image override. When set, overrides transforms.reference_image "
            "(for template_to_patient jobs) or transforms.template_reference_image "
            "(for patient_to_template jobs). Supports {patient_id} substitution. "
            "Use this to resample a single job onto a different grid (e.g. DWI-space mask) "
            "without affecting other jobs in the same workflow."
        ),
    )
    from_registration: Optional[str] = Field(
        None,
        description=(
            "Name of a registration job whose produced transforms drive this job, instead "
            "of the static transforms.* chains. Must match an explicit registration.jobs "
            "name or the implicit 'patient_to_template' job. The job's direction selects "
            "the chain: template_to_patient uses the registration job's reverse chain; "
            "patient_to_template uses the forward chain."
        ),
    )

    @field_validator("direction")
    @classmethod
    def validate_direction(cls, v: str) -> str:
        """Accept only the two supported transform directions.

        Raises:
            ValueError: If ``v`` is neither ``template_to_patient`` nor
                ``patient_to_template``.
        """
        directions = ["template_to_patient", "patient_to_template"]
        if v in directions:
            return v
        raise ValueError(f"direction must be one of {directions}")

    @field_validator("interpolation")
    @classmethod
    def validate_interpolation(cls, v: str) -> str:
        """Accept only interpolation modes from the ANTs allow-list.

        Raises:
            ValueError: If ``v`` is not a listed ANTs interpolation method.
        """
        methods = [
            "Linear",
            "NearestNeighbor",
            "BSpline",
            "MultiLabel",
            "Gaussian",
            "CosineWindowedSinc",
            "WelchWindowedSinc",
            "HammingWindowedSinc",
            "LanczosWindowedSinc",
            "GenericLabel",
        ]
        if v in methods:
            return v
        raise ValueError(f"interpolation must be one of {methods}")


class AtlasTransformConfig(BaseConfig):
    """Template-to-patient transform settings for one named atlas.

    Attributes:
        template_to_patient: A single transform path, or an ordered list of
            transform paths forming a chain; empty string when unset.
        reference_image: Patient-space image defining the resampling grid;
            empty string when unset.
    """

    template_to_patient: Union[str, List[str]] = Field(
        "",
        description=(
            "Full path pattern or ordered list of paths for "
            "template→patient transforms; {patient_id} substituted at runtime"
        ),
    )
    reference_image: str = Field(
        "",
        description=(
            "Full path pattern to reference image "
            "in patient space; {patient_id} substituted at runtime"
        ),
    )


class TransformsConfig(BaseConfig):
    """Pre-computed transforms and the jobs that apply them.

    Attributes:
        base_dir: Optional directory prepended to every relative path field
            during validation.
        patient_to_template: Patient-to-template warp path.
        template_to_patient: Template-to-patient transform, or an ordered
            transform chain.
        reference_image: Patient-space reference image.
        template_reference_image: Template-space reference image (output grid
            for patient-to-template warps).
        atlas_transforms: Atlas-specific transform settings, keyed by name.
        jobs: Transform jobs run by the standalone ``transform`` and
            ``atlas_to_patient`` workflows.
    """

    base_dir: Optional[str] = Field(
        None,
        description="Base directory prepended to any relative path fields at validation time",
    )
    patient_to_template: str = Field(
        "",
        description=(
            "Full path pattern to patient→template warp field; "
            "{patient_id} substituted at runtime"
        ),
    )
    template_to_patient: Union[str, List[str]] = Field(
        "",
        description=(
            "Full path pattern or ordered list of template→patient transforms; "
            "{patient_id} substituted at runtime"
        ),
    )
    reference_image: str = Field(
        "",
        description=(
            "Full path pattern to reference image in patient space; "
            "{patient_id} substituted at runtime"
        ),
    )
    template_reference_image: str = Field(
        "",
        description=(
            "Path to template-space reference image "
            "(defines output grid for patient→template warps)"
        ),
    )
    atlas_transforms: Dict[str, AtlasTransformConfig] = Field(
        description="Named atlas-specific template→patient transforms",
        default_factory=dict,
    )
    jobs: List[TransformJobConfig] = Field(
        description=(
            "Transform jobs for the standalone transform and atlas_to_patient workflows. Each "
            "job specifies input images, direction, interpolation, and output location."
        ),
        default_factory=list,
    )

    @model_validator(mode="after")
    def _apply_base_dir(self) -> "TransformsConfig":
        """Resolve relative path fields against ``base_dir``.

        When ``base_dir`` is set, every non-empty path that is not absolute
        (POSIX rules) becomes ``"<base_dir>/<path>"``, with any trailing ``/``
        stripped from ``base_dir`` first. This covers the top-level path
        fields, each element of a ``template_to_patient`` chain, and the paths
        inside every ``atlas_transforms`` entry. ``jobs`` is left untouched.
        """
        if not self.base_dir:
            return self
        from pathlib import PurePosixPath

        root = self.base_dir.rstrip("/")

        def _anchor(path: str) -> str:
            # Empty means "unset"; absolute paths are already resolved.
            if not path or PurePosixPath(path).is_absolute():
                return path
            return f"{root}/{path}"

        def _anchor_chain(chain: Union[str, List[str]]) -> Union[str, List[str]]:
            if isinstance(chain, list):
                return [_anchor(item) for item in chain]
            return _anchor(chain)

        # object.__setattr__ writes the resolved values without re-entering
        # pydantic's attribute-assignment hooks.
        for attr in ("patient_to_template", "reference_image", "template_reference_image"):
            object.__setattr__(self, attr, _anchor(getattr(self, attr)))
        object.__setattr__(self, "template_to_patient", _anchor_chain(self.template_to_patient))

        rebuilt = {
            key: AtlasTransformConfig(
                template_to_patient=_anchor_chain(atlas.template_to_patient),
                reference_image=_anchor(atlas.reference_image),
            )
            for key, atlas in self.atlas_transforms.items()
        }
        object.__setattr__(self, "atlas_transforms", rebuilt)
        return self


class ValidationConfig(BaseConfig):
    """Thresholds for the automated checks run on warped ROI masks.

    Attributes:
        check_rois: Enables validation of ROI masks after transformation.
        min_voxels: Smallest acceptable number of non-zero voxels in an ROI.
        singularity_threshold: Smallest acceptable absolute determinant of the
            affine's 3x3 rotation/scale block.
        volume_ratio_min: Lower bound, in (0, 1], on the warped/original voxel
            count ratio before a warning is emitted.
        volume_ratio_max: Upper bound, at least 1, on that ratio before a
            warning is emitted.
    """

    check_rois: bool = Field(
        False,
        description="Run automated validation on warped ROI masks after transformation",
    )
    min_voxels: int = Field(
        10,
        description="Minimum non-zero voxel count; fewer voxels raises PipelineError",
        ge=1,
    )
    singularity_threshold: float = Field(
        1e-6,
        description=(
            "Minimum absolute determinant of the affine 3x3 rotation/scale block; below this "
            "the ROI is rejected as degenerate (prevents probtrackx2 crash)"
        ),
        gt=0,
    )
    volume_ratio_min: float = Field(
        0.5,
        description=(
            "Lower bound for warped-vs-original voxel count ratio; below this "
            "a warning is emitted"
        ),
        gt=0.0,
        le=1.0,
    )
    volume_ratio_max: float = Field(
        1.5,
        description=(
            "Upper bound for warped-vs-original voxel count ratio; above this "
            "a warning is emitted"
        ),
        ge=1.0,
    )


class SynthSegConfig(BaseConfig):
    """Options controlling a SynthSeg run, standalone or inside another workflow.

    Attributes:
        t1_image: Optional explicit T1-weighted input; strings are converted to
            ``Path`` via ``to_path`` before validation.
        parc: Request cortical parcellation output.
        robust: Enable robust mode.
        fast: Enable fast mode.
        vol: Write a volumes CSV.
        qc: Write a QC CSV.
        crop: Optional crop size (at least 1).
        cpu: Force CPU execution.
        threads: Thread count used when ``cpu`` is true.
        gpu_runtime_env: Extra environment variables for the GPU SynthSeg
            subprocess only; ignored in CPU mode.
    """

    t1_image: Optional[Path] = Field(None, description="Input T1 image override")
    parc: bool = Field(False, description="Request parcellation output")
    robust: bool = Field(False, description="Enable robust mode")
    fast: bool = Field(False, description="Enable fast mode")
    vol: bool = Field(False, description="Write volumes CSV")
    qc: bool = Field(False, description="Write QC CSV")
    crop: Optional[int] = Field(None, description="Optional crop size", ge=1)
    cpu: bool = Field(False, description="Force CPU execution")
    threads: int = Field(1, description="CPU thread count when cpu=true", ge=1)
    gpu_runtime_env: Optional[Dict[str, str]] = Field(
        None,
        description=(
            "Environment variables injected ONLY into the GPU SynthSeg subprocess "
            "(the mri_synthseg node's ``.inputs.environ``), and thus only into its "
            "Singularity exec / FreeSurfer container — never the probtrackx2 or ANTs "
            "containers. Ignored in CPU mode. Used to make the container's TensorFlow "
            "see the GPU when the image lacks the CUDA runtime: e.g. {'SINGULARITY_NV': "
            "'1', 'APPTAINER_NV': '1', 'SINGULARITYENV_LD_LIBRARY_PATH': "
            "'/path/cuda118/lib:/.singularity.d/libs', 'SINGULARITYENV_TF_FORCE_GPU_"
            "ALLOW_GROWTH': 'true'} (bare SINGULARITY_NV/APPTAINER_NV are read by the "
            "singularity CLI to enable --nv; SINGULARITYENV_/APPTAINERENV_ keys are "
            "forwarded inside the container past --cleanenv)."
        ),
    )

    @field_validator("t1_image", mode="before")
    @classmethod
    def _convert_t1_to_path(cls, v: Any) -> Optional[Path]:
        """Normalise the raw ``t1_image`` value through ``to_path``."""
        converted = to_path(v)
        return cast(Optional[Path], converted)


class NipypeConfig(BaseConfig):
    """Configuration for Nipype workflow execution.

    Attributes:
        working_dir: Base working directory.
        crash_dir: Optional crash dump directory.
        plugin: Nipype execution plugin (validated against an allow-list).
        plugin_args: Plugin arguments passed to ``workflow.run``.
        stop_on_first_crash: Whether execution stops after the first crash.
        remove_unnecessary_outputs: Whether intermediate outputs are cleaned.
        keep_inputs: Whether node input files remain in the working directory.
        hash_method: Caching hash strategy, ``content`` or ``timestamp``
            (compared case-insensitively).
        use_profiler: Whether the Nipype profiler is enabled.
    """

    working_dir: Path = Field(default=Path("work"), description="Base working directory")
    crash_dir: Optional[Path] = Field(default=None, description="Crash dump directory")
    plugin: str = Field(default="MultiProc", description="Nipype execution plugin")
    plugin_args: Dict[str, Any] = Field(default_factory=dict, description="Plugin arguments")
    stop_on_first_crash: bool = Field(default=False, description="Stop on first crash")
    remove_unnecessary_outputs: bool = Field(default=True, description="Clean intermediate outputs")
    keep_inputs: bool = Field(default=True, description="Keep input files in workdir")
    hash_method: str = Field(default="content", description="Hash method for caching")
    use_profiler: bool = Field(default=False, description="Enable Nipype profiler")

    @field_validator("working_dir", "crash_dir", mode="before")
    @classmethod
    def _convert_to_path(cls, v):
        """Convert string paths to Path objects, expanding ~ and env vars."""
        return to_path(v)

    @field_validator("plugin")
    @classmethod
    def validate_plugin(cls, v: str) -> str:
        """Validate the Nipype execution plugin against the supported allow-list."""
        valid = {"Linear", "MultiProc", "SGE", "PBS", "LSF", "SLURM", "SLURMGraph"}
        if v not in valid:
            raise ValueError(f"plugin must be one of {sorted(valid)}, got '{v}'")
        return v

    @field_validator("hash_method")
    @classmethod
    def validate_hash_method(cls, v: str) -> str:
        """Validate the Nipype input-hashing strategy.

        Nipype recognises only ``content`` and ``timestamp``. Checking here
        surfaces a typo at config-load time instead of deep inside a running
        node. The comparison is case-insensitive and the value is returned
        unchanged, so existing mixed-case settings keep working.

        Raises:
            ValueError: If ``v`` is not ``content`` or ``timestamp`` (any case).
        """
        valid = {"content", "timestamp"}
        if v.lower() not in valid:
            raise ValueError(f"hash_method must be one of {sorted(valid)}, got '{v}'")
        return v


class QCConfig(BaseConfig):
    """Configuration for QC visualisation outputs.

    Controls automatic generation of quality-control overlay images
    (ROI placement on anatomical backgrounds, track density maps on
    template images) after workflow execution, as well as extended
    QC checks (SynthSeg quality, outlier detection, etc.).

    Attributes:
        generate_overlays: Whether to produce ROI overlay PNGs at
            the end of the HCP workflow.
        track_density_thresholds: Percentile thresholds used when
            rendering track density figures (applied to non-zero
            voxels of `fdt_paths.nii.gz`). Each must lie strictly
            between 0 and 100; the list is stored sorted ascending.
        synthseg_qc_threshold: Minimum acceptable SynthSeg QC score,
            constrained to [0, 1]. Subjects below this threshold are flagged.
        outlier_sd_threshold: Number of standard deviations from
            the batch mean to flag a subject as an outlier; must be positive.
    """

    generate_overlays: bool = Field(
        default=False,
        description="Generate ROI overlay PNGs after workflow execution",
    )
    track_density_thresholds: List[float] = Field(
        default=[50.0, 90.0, 99.0],
        description="Percentile thresholds for track density map visualisation",
    )
    # Bounds enforce the documented 0-1 QC-score range.
    synthseg_qc_threshold: float = Field(
        default=0.6,
        ge=0.0,
        le=1.0,
        description="Minimum SynthSeg QC score (0-1) before flagging",
    )
    # A zero or negative SD threshold would make the outlier check
    # meaningless; gt=0.0 matches AtlasQCConfig.outlier_sd_threshold.
    outlier_sd_threshold: float = Field(
        default=2.0,
        gt=0.0,
        description="SD threshold for batch outlier detection",
    )

    @field_validator("track_density_thresholds")
    @classmethod
    def validate_thresholds(cls, v: List[float]) -> List[float]:
        """Validate that all thresholds are valid percentiles in (0, 100).

        Returns:
            The thresholds sorted in ascending order.

        Raises:
            ValueError: If any threshold is <= 0 or >= 100.
        """
        for t in v:
            if not (0.0 < t < 100.0):
                raise ValueError(f"Each threshold must be between 0 and 100 (exclusive), got {t}")
        return sorted(v)


class AtlasQCConfig(BaseConfig):
    """Settings that govern QC artefacts produced by the atlas workflow.

    Attributes:
        enabled: Master switch for atlas QC generation (off by default).
        generate_group_plots: Request cohort-level atlas summary plots.
        generate_patient_reports: Request per-patient atlas QC outputs.
        outlier_sd_threshold: Strictly positive number of standard deviations
            from the cohort mean beyond which an atlas-derived value is an outlier.
        subject_density_threshold: Normalised streamline density in [0, 1] a
            subject voxel must reach to count as present.
        group_core_threshold: Fraction of subjects in [0, 1] that must have a
            present voxel (per ``subject_density_threshold``) for that voxel
            to belong to the group core.
        leave_one_out: Compute leave-one-out statistics per subject.
        leave_one_out_min_subjects: Smallest cohort (at least 2) for which
            leave-one-out validation is run.
        compute_cv: Produce Coefficient of Variation (CV) maps.
        cv_mean_threshold: Fraction in [0, 1] of the atlas-mean maximum a voxel
            must reach before its CV is computed.
        compute_log_stats: Compute statistics in log space.
        log_offset: Strictly positive offset added to zero values before the
            log transform.
        output_subdir: Subdirectory of the cohort output directory receiving
            atlas QC artefacts; override it per backend so several backends can
            share one outputs/ tree.
    """

    enabled: bool = Field(description="Enable atlas QC generation", default=False)
    generate_group_plots: bool = Field(
        description="Generate cohort-level atlas QC plots", default=True
    )
    generate_patient_reports: bool = Field(
        description="Generate per-patient atlas QC outputs", default=False
    )
    outlier_sd_threshold: float = Field(
        2.0,
        gt=0.0,
        description="SD threshold for atlas QC outlier detection",
    )
    subject_density_threshold: float = Field(
        0.01,
        ge=0.0,
        le=1.0,
        description="Minimum normalized streamline density for a subject voxel to be considered present",
    )
    group_core_threshold: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="Proportion of subjects required for a voxel to be part of the group core",
    )
    leave_one_out: bool = Field(
        description="Compute leave-one-out statistics for each subject", default=False
    )
    leave_one_out_min_subjects: int = Field(
        3,
        ge=2,
        description="Minimum number of subjects required to run leave-one-out validation",
    )
    compute_cv: bool = Field(
        description="Compute Coefficient of Variation (CV) maps", default=True
    )
    cv_mean_threshold: float = Field(
        0.01,
        ge=0.0,
        le=1.0,
        description="Fraction of atlas mean maximum required to compute CV",
    )
    compute_log_stats: bool = Field(
        description="Compute statistics in log space", default=True
    )
    log_offset: float = Field(
        1e-6,
        gt=0.0,
        description="Small offset added to zero values before log transformation",
    )
    output_subdir: str = Field(
        "atlas_qc",
        description=(
            "Subdirectory under the cohort output directory where atlas QC "
            "artefacts are written. Override per backend (e.g. "
            "'atlas_qc_mrtrix3') to coexist with another backend's QC "
            "under the same outputs/ tree."
        ),
    )


class SideThresholdConfig(BaseConfig):
    """Cutoff used to binarise one side (subject or atlas) of a tract comparison.

    Attributes:
        mode: ``"fraction"`` scales the cutoff as ``value * max(volume)``;
            ``"absolute"`` treats ``value`` as a raw voxel-intensity cutoff.
        value: Strictly positive threshold. In ``"fraction"`` mode it must
            additionally be below 1, i.e. lie in ``(0, 1)``.
    """

    mode: Literal["fraction", "absolute"] = Field(
        "fraction",
        description="Thresholding mode: 'fraction' of volume max, or 'absolute' voxel intensity.",
    )
    value: float = Field(
        0.05,
        gt=0.0,
        description="Threshold value. Must be in (0, 1) when mode='fraction'; any positive number when mode='absolute'.",
    )

    @model_validator(mode="after")
    def _validate_fraction_range(self) -> "SideThresholdConfig":
        """Reject fractional thresholds that fall outside the open unit interval."""
        if self.mode != "fraction":
            # Absolute cutoffs only need the field-level ``gt=0`` constraint.
            return self
        inside_unit_interval = 0.0 < self.value < 1.0
        if not inside_unit_interval:
            raise ValueError("value must be in (0, 1) when mode='fraction'")
        return self


class HcpLooConfig(BaseConfig):
    """Settings for the ``tract_similarity_hcp_loo`` workflow.

    Each HCP subject is compared against a cohort atlas built without that
    subject (leave-one-out). Two subjects are the mathematical minimum (the
    LOO atlas is then just the other subject's volume); the default floor of
    three keeps the reference meaningfully cohort-like.

    Attributes:
        minimum_subjects: Smallest cohort (at least 2) for which the workflow runs.
        write_volumes: When True, each subject gets
            ``subject_normalized.nii.gz``, ``atlas_normalized.nii.gz``,
            ``subject_mask.nii.gz`` and ``atlas_mask.nii.gz`` next to
            ``metrics.json``; when False only ``metrics.json`` is written.
        learned_prediction_relpath: Relative path, under each subject
            directory, to a template-space learned-atlas prediction NIfTI.
            The empty-string default turns the learned-template third arm off.
        learned_support_threshold: Non-negative density cutoff defining the
            held-out subject's true support for the non-circular learned-arm
            metrics.
    """

    minimum_subjects: int = Field(
        3,
        ge=2,
        description="Minimum cohort size required to run the LOO comparison. "
        "Mathematically valid at 2; default 3 enforces a meaningful cohort.",
    )
    write_volumes: bool = Field(
        True,
        description="If True, write the four NIfTI volumes per subject alongside "
        "metrics.json. Set to False to skip them (saves disk on large cohorts).",
    )
    learned_prediction_relpath: str = Field(
        "",
        description="Optional relative path under each subject directory to a learned-"
        "atlas per-subject prediction NIfTI (template space). Empty disables "
        "the learned-template 3rd arm.",
    )
    learned_support_threshold: float = Field(
        0.0,
        ge=0.0,
        description="Density threshold defining the held-out subject's true support for "
        "the non-circular learned-arm metrics.",
    )


class TractSimilarityConfig(BaseConfig):
    """Configuration for the tract_similarity analysis workflow.

    Controls the per-patient comparison between the native-space probtrackx2
    tract density and the warped cohort mean atlas, plus the cohort-level
    aggregation of those per-patient metrics.

    Attributes:
        probtrackx_relpath: Relative path to the probtrackx2 output directory.
            Both single-run and hemisphere-split (``left/`` + ``right/``)
            layouts are accepted; in the split case the waytotal-normalised
            volumes are summed across hemispheres.
        fdt_name: Filename of the probtrackx2 density volume inside each run
            directory (default ``fdt_paths.nii.gz``).
        waytotal_name: Filename of the probtrackx2 waytotal text file inside
            each run directory (default ``waytotal``).
        atlas_relpath: Relative path (or glob pattern) under the patient output
            directory that points at the warped atlas mean volume.
        subject_threshold: Binarisation threshold applied to the subject's
            probtrackx2 volume before overlap / distance metrics.
        atlas_threshold: Binarisation threshold applied to the warped atlas
            volume before overlap / distance metrics.
        n_bins: Histogram bin count (8 to 1024) for normalised mutual information.
        output_subdir: Subdirectory under the patient output directory where
            ``metrics.json`` is written.
        cohort_output_subdir: Subdirectory under the cohort output directory
            where the aggregated summary is written.
        hcp_loo: Settings for the per-HCP-subject leave-one-out comparison
            (see ``HcpLooConfig``).
    """

    probtrackx_relpath: str = Field(
        default="tractography/probtrackx2",
        description=(
            "Relative path to the probtrackx2 output directory. Auto-detects "
            "single-run vs hemisphere-split (left/ + right/) layout and sums "
            "the waytotal-normalised volumes across hemispheres, matching how "
            "the atlas workflow builds its cohort stack."
        ),
    )
    fdt_name: str = Field(
        default="fdt_paths.nii.gz",
        description="Filename of the probtrackx2 density volume within each run dir",
    )
    waytotal_name: str = Field(
        default="waytotal",
        description="Filename of the probtrackx2 waytotal text file within each run dir",
    )
    atlas_relpath: str = Field(
        default="atlas_in_patient_space/atlas_mean*.nii.gz",
        description=(
            "Relative path or glob pattern for the warped cohort mean atlas "
            "in the patient's native space"
        ),
    )
    subject_threshold: SideThresholdConfig = Field(
        default_factory=SideThresholdConfig,
        description="Binarisation threshold for the subject's probtrackx2 volume.",
    )
    atlas_threshold: SideThresholdConfig = Field(
        default_factory=SideThresholdConfig,
        description="Binarisation threshold for the warped atlas volume.",
    )
    n_bins: int = Field(
        default=64,
        ge=8,
        le=1024,
        description="Number of histogram bins for normalised mutual information",
    )
    output_subdir: str = Field(
        default="tract_similarity",
        description="Patient-level output subdirectory for metrics.json",
    )
    cohort_output_subdir: str = Field(
        default="cohort/tract_similarity",
        description="Cohort-level output subdirectory for aggregated metrics",
    )
    hcp_loo: HcpLooConfig = Field(
        default_factory=HcpLooConfig,
        description=(
            "Per-HCP-subject leave-one-out comparison against the cohort atlas. "
            "See HcpLooConfig for details."
        ),
    )


class ThresholdGridConfig(BaseConfig):
    """One binarisation-threshold axis of the sweep grid.

    The axis is described either by an explicit ``values`` list or by a
    ``start``/``stop``/``step`` range (endpoint inclusive within float
    tolerance). When ``values`` is set, the range fields are not checked
    further. Every threshold must lie strictly inside ``(0, 1)`` because the
    sweep operates in ``"fraction"`` mode.
    """

    start: Optional[float] = Field(0.05, gt=0.0, lt=1.0)
    stop: Optional[float] = Field(0.50, gt=0.0, lt=1.0)
    step: Optional[float] = Field(0.025, gt=0.0)
    values: Optional[List[float]] = Field(default=None)

    @model_validator(mode="after")
    def _one_form(self) -> "ThresholdGridConfig":
        """Check that exactly one usable grid description is present."""
        explicit = self.values
        if explicit is not None:
            acceptable = len(explicit) > 0 and all(0.0 < item < 1.0 for item in explicit)
            if not acceptable:
                raise ValueError("values must be a non-empty list of floats in (0, 1)")
            return self

        range_parts = (self.start, self.stop, self.step)
        if any(part is None for part in range_parts):
            raise ValueError("specify values OR start/stop/step")
        if self.stop <= self.start:
            raise ValueError("stop must be > start")
        return self


class TractSimilaritySweepConfig(BaseConfig):
    """Settings for the ``tract_similarity_sweep`` cohort grid search.

    The sweep varies ``subject_threshold.value`` and ``atlas_threshold.value``
    (both interpreted in ``"fraction"`` mode) and reports the grid cell whose
    cohort-aggregated Dice score is highest.

    Attributes:
        subject_threshold_grid: Grid over the subject-side fraction threshold.
        atlas_threshold_grid: Grid over the atlas-side fraction threshold.
        aggregation: Cohort statistic (``"mean"`` or ``"median"``) used to
            rank grid cells.
        output_subdir: Output subdirectory under the cohort output root.
        emit_heatmap: Whether to write a matplotlib PNG heatmap of the grid.
    """

    subject_threshold_grid: ThresholdGridConfig = Field(
        description="Grid over subject_threshold.value (fraction-of-max).",
        default_factory=ThresholdGridConfig,
    )
    atlas_threshold_grid: ThresholdGridConfig = Field(
        description="Grid over atlas_threshold.value (fraction-of-max).",
        default_factory=ThresholdGridConfig,
    )
    aggregation: Literal["mean", "median"] = Field(
        "mean", description="Cohort aggregation used to pick the best grid cell."
    )
    output_subdir: str = Field(
        "tract_similarity_sweep",
        description="Sweep output subdirectory under the cohort output root.",
    )
    emit_heatmap: bool = Field(
        True, description="Emit a matplotlib heatmap (PNG) of the aggregated grid."
    )


class OutputSettingsConfig(BaseConfig):
    """YAML-level defaults for how the CLI presents its output.

    The runtime CLI flags ``-v``, ``-q``, ``--summary`` and ``--no-progress``
    take precedence over anything configured here.

    Attributes:
        verbosity: One of ``"quiet"``, ``"normal"``, ``"verbose"``.
        summary: Amount of end-of-run summary: ``"off"``, ``"compact"``, ``"full"``.
        progress: Progress bars/spinners: ``"auto"`` (on for a TTY, off for
            pipes/CI), ``"on"`` or ``"off"``.
        output_format: ``"human"`` or ``"json"``.
    """

    verbosity: str = Field("normal", description="Default verbosity: quiet | normal | verbose")
    summary: str = Field("compact", description="Summary detail: off | compact | full")
    progress: str = Field("auto", description="Progress UI: auto | on | off")
    output_format: str = Field("human", description="Output format: human | json")

    @staticmethod
    def _ensure_choice(field_name: str, value: str, choices: tuple) -> str:
        """Return ``value`` unchanged if it is one of ``choices``, else raise ValueError."""
        if value in choices:
            return value
        raise ValueError(f"{field_name} must be one of {choices}")

    @field_validator("verbosity")
    @classmethod
    def validate_verbosity(cls, v: str) -> str:
        """Allow only the supported verbosity levels."""
        return cls._ensure_choice("verbosity", v, ("quiet", "normal", "verbose"))

    @field_validator("summary")
    @classmethod
    def validate_summary(cls, v: str) -> str:
        """Allow only the supported summary detail levels."""
        return cls._ensure_choice("summary", v, ("off", "compact", "full"))

    @field_validator("progress")
    @classmethod
    def validate_progress(cls, v: str) -> str:
        """Allow only the supported progress UI modes."""
        return cls._ensure_choice("progress", v, ("auto", "on", "off"))

    @field_validator("output_format")
    @classmethod
    def validate_output_format(cls, v: str) -> str:
        """Allow only the supported output formats."""
        return cls._ensure_choice("output_format", v, ("human", "json"))


class SLURMConfig(BaseConfig):
    """Configuration for SLURM-based grouped stage submission.

    When ``enabled`` is ``False`` (the default) the SLURM submit path is never
    taken and pipeline behavior is identical to local execution. The model is
    always instantiated on ``PipelineConfig`` (via ``default_factory``), so
    ``cfg.slurm.enabled`` never raises ``AttributeError``. Walltime strings
    are validated here; path-like fields are normalised to ``Path`` objects.

    Attributes:
        enabled: Whether grouped SLURM submission is used.
        account: SLURM account/charge string (``--account``).
        cpu_partition: Partition for CPU-only stage groups.
        gpu_partition: Default (short) GPU partition.
        gpu_partition_long: Long-running GPU partition (``None`` to disable).
        gpu_short_max_minutes: ``est_minutes`` above this route to the long
            GPU partition.
        walltime_cpu: Walltime (``H+:MM:SS``) for CPU groups.
        walltime_gpu: Walltime (``H+:MM:SS``) for short GPU groups.
        walltime_gpu_long: Walltime (``H+:MM:SS``) for long GPU groups.
        max_concurrent_gpu_jobs: In-flight GPU job cap (matches physical GPUs).
        max_jobs_per_wave: Maximum in-flight jobs per submission wave (1-2000).
        submit_dir: Shared-FS directory for sbatch scripts / logs / manifest.
        apptainer_image: Path to the Apptainer image used by stage jobs.
        env_setup: Path to an environment-setup script sourced before runs.
        sbatch_extra_args: Extra sbatch args keyed by compute class
            (e.g. ``{"cpu": [...], "gpu": [...]}``).
    """

    enabled: bool = Field(default=False, description="Enable grouped SLURM submission")
    account: str = Field(default="", description="SLURM account/charge string")
    cpu_partition: str = Field(default="CPU_short", description="Partition for CPU-only groups")
    gpu_partition: str = Field(default="GPU_short", description="Default (short) GPU partition")
    gpu_partition_long: Optional[str] = Field(
        default="GPU_long", description="Long-running GPU partition (None to disable)"
    )
    gpu_short_max_minutes: int = Field(
        default=120, description="est_minutes above this route to the long GPU partition"
    )
    walltime_cpu: str = Field(default="04:00:00", description="Walltime for CPU groups")
    walltime_gpu: str = Field(default="04:00:00", description="Walltime for short GPU groups")
    walltime_gpu_long: str = Field(default="24:00:00", description="Walltime for long GPU groups")
    max_concurrent_gpu_jobs: int = Field(
        default=6, description="In-flight GPU job cap (matches physical GPUs)"
    )
    max_jobs_per_wave: int = Field(
        default=1800, ge=1, le=2000, description="Maximum in-flight jobs per submission wave"
    )
    submit_dir: Optional[Path] = Field(
        default=None, description="Shared-FS directory for sbatch scripts, logs, and manifest"
    )
    apptainer_image: Optional[Path] = Field(
        default=None, description="Path to the Apptainer image used by stage jobs"
    )
    env_setup: Optional[Path] = Field(
        default=None, description="Path to an environment-setup script sourced before runs"
    )
    sbatch_extra_args: Dict[str, List[str]] = Field(
        default_factory=dict,
        description="Extra sbatch args keyed by compute class (e.g. cpu/gpu)",
    )

    @field_validator("submit_dir", "apptainer_image", "env_setup", mode="before")
    @classmethod
    def _convert_to_path(cls, v):
        """Convert string paths to Path objects, expanding ~ and env vars."""
        return to_path(v)

    @field_validator("walltime_cpu", "walltime_gpu", "walltime_gpu_long")
    @classmethod
    def _validate_walltime(cls, v: str) -> str:
        """Validate SLURM walltimes have an ``H+:MM:SS`` shape.

        Raises:
            ValueError: If ``v`` is not one or more ASCII digits, a colon, two
                digits, a colon, and two digits, with nothing else around it.
        """
        # fullmatch + [0-9] instead of match(r"^\d+...$"): "$" also matches just
        # before a trailing newline, and \d accepts non-ASCII Unicode digits, so
        # the looser pattern let values like "04:00:00\n" through.
        if re.fullmatch(r"[0-9]+:[0-9]{2}:[0-9]{2}", v) is None:
            raise ValueError(f"walltime must match H+:MM:SS (e.g. '04:00:00'), got '{v}'")
        return v


class PipelineConfig(BaseModel):
    """Root configuration model for thesis pipelines.

    Aggregates all sub-configurations and provides convenience methods.

    Attributes:
        paths: Filesystem paths.
        hardware: Compute resources.
        atlas: Atlas generation settings.
        s3: S3 data download settings (optional).
        preprocessing: Generic preprocessing steps.
        preprocess: Workflow-specific raw-to-HCP preprocessing settings (raw dict).
        mrtrix3: Workflow-specific settings for the mrtrix3 workflow (raw dict).
        registration: Image registration.
        segmentation: Brain segmentation.
        tractography: Tractography parameters.
        hcp: HCP-specific overrides.
        transforms: Pre-computed transform settings.
        nipype: Nipype execution settings.
        slurm: Grouped SLURM submission settings.
        validation: ROI validation settings.
        qc: QC generation settings.
        atlas_qc: Atlas QC generation settings.
        synthseg: Optional SynthSeg-specific overrides.
        output: CLI output defaults.
        tract_similarity: Tract similarity analysis settings.
        tract_similarity_sweep: Grid-search sweep over tract similarity thresholds.
        patient_id: Optional patient identifier.
        protocol: Optional protocol name.

    Additional top-level keys are accepted only if they name a workflow
    namespace registered in ``NAMESPACE_REGISTRY``; those are exposed as
    instances of the registered schema.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    paths: PathConfig = Field(default_factory=PathConfig)
    hardware: HardwareConfig = Field(default_factory=HardwareConfig)
    atlas: AtlasConfig = Field(default_factory=AtlasConfig)
    s3: Optional[S3Config] = Field(
        default=None,
        description="S3 data download configuration (optional, for cloud deployments)",
    )
    preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
    preprocess: Dict[str, Any] = Field(
        default_factory=dict,
        description="Workflow-specific config for the preprocess workflow",
    )
    mrtrix3: Dict[str, Any] = Field(
        default_factory=dict,
        description="Workflow-specific config for the mrtrix3 workflow",
    )
    registration: RegistrationConfig = Field(default_factory=RegistrationConfig)
    segmentation: SegmentationConfig = Field(default_factory=SegmentationConfig)
    tractography: TractographyConfig = Field(default_factory=TractographyConfig)
    hcp: HCPConfig = Field(default_factory=HCPConfig)
    transforms: TransformsConfig = Field(default_factory=TransformsConfig)
    nipype: NipypeConfig = Field(default_factory=NipypeConfig)
    slurm: SLURMConfig = Field(default_factory=SLURMConfig)
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    qc: QCConfig = Field(
        default_factory=QCConfig,
        description="QC visualisation settings (ROI overlays, track density figures)",
    )
    atlas_qc: AtlasQCConfig = Field(
        default_factory=AtlasQCConfig,
        description="Atlas QC settings for cohort-level and per-patient atlas outputs",
    )
    synthseg: Optional[SynthSegConfig] = Field(
        default=None,
        description="SynthSeg-specific config overrides (e.g. t1_image)",
    )
    output: OutputSettingsConfig = Field(
        default_factory=OutputSettingsConfig,
        description="CLI output behavior defaults (verbosity, summary, progress, format)",
    )
    tract_similarity: TractSimilarityConfig = Field(
        default_factory=TractSimilarityConfig,
        description="Tract similarity analysis settings (per-patient and cohort aggregation)",
    )
    tract_similarity_sweep: TractSimilaritySweepConfig = Field(
        default_factory=TractSimilaritySweepConfig,
        description="Grid-search sweep over tract_similarity binarization thresholds.",
    )

    patient_id: Optional[str] = Field(default=None, description="Patient identifier")
    protocol: Optional[str] = Field(default=None, description="Acquisition protocol")

    @model_validator(mode="before")
    @classmethod
    def _validate_workflow_namespaces(cls, data: Any) -> Any:
        """Validate workflow-registered namespaces and reject unknown keys.

        ``extra="allow"`` on PipelineConfig lets unknown top-level keys reach
        this validator. For each unknown key we look up
        :data:`NAMESPACE_REGISTRY` — keys matching a registered workflow
        namespace are validated against the registered schema; keys with no
        match raise ``ValueError`` to preserve the historical typo-catching
        behaviour of ``extra="forbid"``.

        Args:
            data: Raw input. Non-dict inputs are passed through untouched.

        Returns:
            A shallow copy of ``data`` in which every registered namespace
            value has been replaced by its validated schema instance. The
            caller's dict is never modified.

        Raises:
            ValueError: For a top-level key that is neither a static field
                nor a registered workflow namespace.
        """
        if not isinstance(data, dict):
            return data
        from thesis.core.config.namespace_registry import NAMESPACE_REGISTRY

        # Copy before replacing values: the input dict may be the caller's own
        # object (e.g. a loaded YAML dict handed to from_dict). Writing schema
        # instances into it would leak model objects back to the caller and
        # could let several configs built from the same dict share state.
        data = dict(data)
        known_static = set(cls.model_fields.keys())
        for key, value in list(data.items()):
            if key in known_static:
                continue
            schema = NAMESPACE_REGISTRY.get(key)
            if schema is None:
                # f-string formatting (rather than ``"..." + key``) keeps this a
                # ValueError even for non-string YAML keys such as ``1:`` or ``yes:``.
                raise ValueError(
                    f"Unknown top-level config key '{key}'. "
                    f"Known fields: {sorted(known_static)}. "
                    f"Registered workflow namespaces: {NAMESPACE_REGISTRY.list()}. "
                    f"If '{key}' is a workflow namespace, make sure the workflow "
                    "module is imported before validating the config."
                )
            raw = value if value is not None else {}
            # Validate once here (surfacing schema errors early) and store the
            # typed instance directly. ``extra="allow"`` +
            # ``arbitrary_types_allowed=True`` pass the instance through core
            # validation unchanged, so the after-validator can reuse it without
            # a second parse.
            data[key] = schema.model_validate(raw)
        return data

    @model_validator(mode="after")
    def _instantiate_workflow_namespaces(self) -> "PipelineConfig":
        """Materialise registered namespaces as their typed Pydantic instances.

        ``extra="allow"`` stores unknown keys as raw values on the model.
        After base validation runs, replace each raw dict with an instance
        of the registered schema so workflow bodies can use attribute
        access (``config.my_wf.threshold``). Namespaces absent from the
        YAML get a default-constructed instance.
        """
        from thesis.core.config.namespace_registry import NAMESPACE_REGISTRY

        known_static = set(type(self).model_fields.keys())
        for ns, schema in NAMESPACE_REGISTRY.items():
            if ns in known_static:
                continue
            raw = getattr(self, ns, None)
            if raw is None:
                object.__setattr__(self, ns, schema())
            elif isinstance(raw, schema):
                # Already validated/instantiated by the before-validator.
                continue
            elif isinstance(raw, dict):
                # Defensive fallback for a raw dict that did not pass through
                # the before-validator — validate it here.
                object.__setattr__(self, ns, schema.model_validate(raw))
        return self

    @model_validator(mode="after")
    def _validate_from_registration_references(self) -> "PipelineConfig":
        """Ensure every transforms.jobs[*].from_registration resolves to a job.

        A reference must match either an explicit ``registration.jobs`` name or
        the implicit ``patient_to_template`` job that is synthesized from the
        top-level registration fields when ``registration.jobs`` is empty.

        Raises:
            ValueError: If any reference names an unknown registration job.
        """
        referenced = {
            job.from_registration
            for job in self.transforms.jobs
            if job.from_registration is not None
        }
        if not referenced:
            return self

        if self.registration.jobs:
            known = {job.name for job in self.registration.jobs}
        else:
            known = {"patient_to_template"}

        unknown = referenced - known
        if unknown:
            raise ValueError(
                "transforms.jobs[*].from_registration references unknown registration "
                f"job name(s): {sorted(unknown)}. Known registration job(s): {sorted(known)}."
            )
        return self

    @model_validator(mode="after")
    def _validate_slurm_plugin_consistency(self) -> "PipelineConfig":
        """Reject ``SLURMGraph`` when grouped SLURM submission is enabled.

        The grouped SLURM submit path coalesces nodes into resource-homogeneous
        stage jobs.  Nipype's ``SLURMGraph`` plugin submits one job per node,
        which defeats the job-count budget; point users at grouped submission.
        """
        if self.slurm.enabled and self.nipype.plugin == "SLURMGraph":
            raise ValueError(
                "nipype.plugin='SLURMGraph' is incompatible with slurm.enabled=true; "
                "grouped SLURM submission handles job emission, so set "
                "nipype.plugin to 'MultiProc' (or 'Linear')."
            )
        return self

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "PipelineConfig":
        """
        Create a PipelineConfig from a dictionary.

        Convenience wrapper around ``model_validate`` that provides a
        domain-specific name and ensures the return type is narrowed.
        ``config_dict`` itself is not modified.

        Args:
            config_dict: Dictionary of configuration values

        Returns:
            Validated PipelineConfig object

        Example:
            >>> config_dict = {"hardware": {"threads": 8}}
            >>> config = PipelineConfig.from_dict(config_dict)
        """
        return cast("PipelineConfig", cls.model_validate(config_dict))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert config to dictionary.

        Returns:
            Dictionary representation of config (``None`` values included)

        Example:
            >>> config_dict = config.to_dict()
        """
        return cast(Dict[str, Any], self.model_dump(mode="python", exclude_none=False))

    def merge_with(self, other: Union["PipelineConfig", Dict[str, Any]]) -> "PipelineConfig":
        """
        Merge this config with another config or dict.

        Both sides are converted to plain dicts, combined with
        ``loaders.merge_configs`` (``other`` as the second argument), and the
        result is re-validated into a new instance.

        Args:
            other: Another PipelineConfig or dict to merge

        Returns:
            New PipelineConfig with merged values

        Example:
            >>> merged = config.merge_with({"hardware": {"threads": 16}})
        """
        from .loaders import merge_configs

        base_dict = self.to_dict()
        if isinstance(other, PipelineConfig):
            other_dict = other.to_dict()
        else:
            other_dict = other

        merged_dict = merge_configs(base_dict, other_dict)
        return self.__class__.from_dict(merged_dict)
